reworked scalar

Jean-Philippe Bossuat
2025-04-30 23:11:43 +02:00
parent 6f7b93c7ca
commit 9ade995cd7
8 changed files with 311 additions and 338 deletions

View File

@@ -1,6 +1,6 @@
use base2k::{
Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
};
use itertools::izip;
use sampling::source::Source;
@@ -19,14 +19,14 @@ fn main() {
let mut source: Source = Source::new(seed);
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: Scalar = Scalar::new(n);
s.fill_ternary_prob(0.5, &mut source);
let mut s: Scalar = module.new_scalar(1);
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft(s.cols());
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, &s);
module.svp_prepare(&mut s_dft, 0, &s, 0);
// Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx = module.new_vec_znx(
@@ -48,6 +48,7 @@ fn main() {
&mut buf_dft, // DFT(ct[1] * s)
0, // Selects the first column of res
&s_dft, // DFT(s)
0, // Selects the first column of s_dft
&ct,
1, // Selects the second column of ct
);
@@ -106,6 +107,7 @@ fn main() {
&mut buf_dft,
0, // Selects the first column of res.
&s_dft,
0,
&ct,
1, // Selects the second column of ct (ct[1])
);
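Putting the reworked main.rs calls together, a minimal sketch of the new column-indexed sequence (a sketch only; it assumes a Module<FFT64> named `module` and a seeded Source named `source`, as in the example above):

// Secret with a single column; column indices are now passed explicitly.
let mut s: Scalar = module.new_scalar(1);
s.fill_ternary_prob(0, 0.5, &mut source);

// DFT buffer sized by the number of columns of s.
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft(s.cols());

// Column 0 of s_dft <- DFT(column 0 of s).
module.svp_prepare(&mut s_dft, 0, &s, 0);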

View File

@@ -5,7 +5,9 @@ pub mod ffi;
pub mod mat_znx_dft;
pub mod module;
pub mod sampling;
pub mod scalar_znx;
pub mod scalar_znx_dft;
pub mod scalar_znx_dft_ops;
pub mod stats;
pub mod vec_znx;
pub mod vec_znx_big;
@@ -19,8 +21,11 @@ pub use encoding::*;
pub use mat_znx_dft::*;
pub use module::*;
pub use sampling::*;
#[allow(unused_imports)]
pub use scalar_znx::*;
pub use scalar_znx_dft::*;
#[allow(unused_imports)]
pub use scalar_znx_dft_ops::*;
pub use stats::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
@@ -50,13 +55,13 @@ pub fn assert_alignement<T>(ptr: *const T) {
pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
let ptr: *mut V = data.as_ptr() as *mut V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
@@ -70,7 +75,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
align
);
assert_eq!(
(size * std::mem::size_of::<u8>()) % align,
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
@@ -98,22 +103,25 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
/// Size of T * size must be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * std::mem::size_of::<T>()) % align,
(size * size_of::<T>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
let len: usize = vec_u8.len() / size_of::<T>();
let cap: usize = vec_u8.capacity() / size_of::<T>();
std::mem::forget(vec_u8);
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned of size equal to the smallest multiple
/// of [DEFAULTALIGN] that is equal or greater to `size`.
/// Allocates an aligned vector of size equal to the smallest multiple
/// of [DEFAULTALIGN]/size_of::<T>() that is equal to or greater than `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(size + (size % DEFAULTALIGN), DEFAULTALIGN)
alloc_aligned_custom::<T>(
size + (size % (DEFAULTALIGN / size_of::<T>())),
DEFAULTALIGN,
)
}

View File

@@ -160,7 +160,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `b`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>);
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
///
@@ -170,7 +170,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `a_size`: size of the input [VecZnx].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: size of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
///
@@ -404,7 +404,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
@@ -422,14 +422,14 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
res_size as u64,
a_size as u64,
gct_rows as u64,
gct_size as u64,
b_rows as u64,
b_size as u64,
) as usize
}
}
@@ -595,7 +595,7 @@ mod tests {
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
assert_eq!(a_dft.raw(), b_dft.raw());
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
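The row index now sits between the destination and the source, so call sites read destination, row index, source. A before/after sketch using the identifiers from the test above:

// before: module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
// after:
module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);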

base2k/src/scalar_znx.rs Normal file
View File

@@ -0,0 +1,113 @@
use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, GetZnxBase, Module, VecZnx};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
pub const SCALAR_ZNX_ROWS: usize = 1;
pub const SCALAR_ZNX_SIZE: usize = 1;
pub struct Scalar {
pub inner: ZnxBase,
}
impl GetZnxBase for Scalar {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
impl ZnxInfos for Scalar {}
impl<B: Backend> ZnxAlloc<B> for Scalar {
type Scalar = i64;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes),
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
debug_assert_eq!(
_rows, SCALAR_ZNX_ROWS,
"rows != {} not supported for Scalar",
SCALAR_ZNX_ROWS
);
debug_assert_eq!(
_size, SCALAR_ZNX_SIZE,
"rows != {} not supported for Scalar",
SCALAR_ZNX_SIZE
);
module.n() * cols * std::mem::size_of::<Self::Scalar>()
}
}
impl ZnxLayout for Scalar {
type Scalar = i64;
}
impl ZnxSliceSize for Scalar {
fn sl(&self) -> usize {
self.n()
}
}
impl Scalar {
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.at_mut(col, 0)
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.at_mut(col, 0)[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.at_mut(col, 0).shuffle(source);
}
pub fn alias_as_vec_znx(&self) -> VecZnx {
VecZnx {
inner: ZnxBase {
n: self.n(),
rows: 1,
cols: 1,
size: 1,
data: Vec::new(),
ptr: self.ptr() as *mut u8,
},
}
}
}
pub trait ScalarOps {
fn bytes_of_scalar(&self, cols: usize) -> usize;
fn new_scalar(&self, cols: usize) -> Scalar;
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> Scalar;
fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar;
}
impl<B: Backend> ScalarOps for Module<B> {
fn bytes_of_scalar(&self, cols: usize) -> usize {
Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE)
}
fn new_scalar(&self, cols: usize) -> Scalar {
Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE)
}
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> Scalar {
Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
}
fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
}
}
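A short usage sketch of the new column-aware Scalar API (assumptions: a Module<FFT64> named `module` with degree at least 64 and a Source named `source` already exist):

// Two-column ternary secret: column 0 sampled by probability,
// column 1 sampled with a fixed Hamming weight of 64.
let mut s: Scalar = module.new_scalar(2);
s.fill_ternary_prob(0, 0.5, &mut source);
s.fill_ternary_hw(1, 64, &mut source);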

View File

@@ -1,279 +1,66 @@
use std::marker::PhantomData;
use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
use crate::ffi::svp;
use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, GetZnxBase, Module};
pub struct Scalar {
pub n: usize,
pub data: Vec<i64>,
pub ptr: *mut i64,
}
impl<B: Backend> Module<B> {
pub fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
}
impl Scalar {
pub fn new(n: usize) -> Self {
let mut data: Vec<i64> = alloc_aligned::<i64>(n);
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
data: data,
ptr: ptr,
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn bytes_of(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::new(),
ptr: ptr,
}
}
pub fn as_ptr(&self) -> *const i64 {
self.ptr
}
pub fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
}
pub fn raw_mut(&self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
}
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.data
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.data[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.data.shuffle(source);
}
pub fn as_vec_znx(&self) -> VecZnx {
VecZnx {
inner: ZnxBase {
n: self.n,
rows: 1,
cols: 1,
size: 1,
data: Vec::new(),
ptr: self.ptr as *mut u8,
},
}
}
}
pub trait ScalarOps {
fn bytes_of_scalar(&self) -> usize;
fn new_scalar(&self) -> Scalar;
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
}
impl<B: Backend> ScalarOps for Module<B> {
fn bytes_of_scalar(&self) -> usize {
Scalar::bytes_of(self.n())
}
fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes(self.n(), bytes)
}
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
}
}
pub const SCALAR_ZNX_DFT_ROWS: usize = 1;
pub const SCALAR_ZNX_DFT_SIZE: usize = 1;
pub struct ScalarZnxDft<B: Backend> {
pub n: usize,
pub data: Vec<u8>,
pub ptr: *mut u8,
pub inner: ZnxBase,
_marker: PhantomData<B>,
}
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
/// An [SvpPPol] can be seen as a [VecZnxDft] of one limb.
impl ScalarZnxDft<FFT64> {
pub fn new(module: &Module<FFT64>) -> Self {
module.new_scalar_znx_dft()
impl<B: Backend> GetZnxBase for ScalarZnxDft<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
}
/// Returns the ring degree of the [SvpPPol].
pub fn n(&self) -> usize {
self.n
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
pub fn bytes_of(module: &Module<FFT64>) -> usize {
module.bytes_of_scalar_znx_dft()
}
impl<B: Backend> ZnxInfos for ScalarZnxDft<B> {}
pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft());
}
unsafe {
Self {
n: module.n(),
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
_marker: PhantomData,
}
}
}
impl<B: Backend> ZnxAlloc<B> for ScalarZnxDft<B> {
type Scalar = u8;
pub fn from_bytes_borrow(module: &Module<FFT64>, tmp_bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft());
}
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
Self {
n: module.n(),
data: Vec::new(),
ptr: tmp_bytes.as_mut_ptr(),
inner: ZnxBase::from_bytes_borrow(
module.n(),
SCALAR_ZNX_DFT_ROWS,
cols,
SCALAR_ZNX_DFT_SIZE,
bytes,
),
_marker: PhantomData,
}
}
/// Returns the number of cols of the [SvpPPol], which is always 1.
pub fn cols(&self) -> usize {
1
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
debug_assert_eq!(
_rows, SCALAR_ZNX_DFT_ROWS,
"rows != {} not supported for ScalarZnxDft",
SCALAR_ZNX_DFT_ROWS
);
debug_assert_eq!(
_size, SCALAR_ZNX_DFT_SIZE,
"rows != {} not supported for ScalarZnxDft",
SCALAR_ZNX_DFT_SIZE
);
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
}
}
pub trait ScalarZnxDftOps<B: Backend> {
/// Allocates a new [SvpPPol].
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
fn bytes_of_scalar_znx_dft(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
impl ZnxLayout for ScalarZnxDft<FFT64> {
type Scalar = f64;
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_scalar_znx_dft());
let ptr: *mut u8 = data.as_mut_ptr();
ScalarZnxDft::<FFT64> {
data: data,
ptr: ptr,
n: self.n(),
_marker: PhantomData,
}
}
fn bytes_of_scalar_znx_dft(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
}
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes(self, bytes)
}
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes_borrow(self, tmp_bytes)
}
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) }
}
fn svp_apply_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
unsafe {
svp::svp_apply_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
a.ptr as *const svp_ppol_t,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
impl ZnxSliceSize for ScalarZnxDft<FFT64> {
fn sl(&self) -> usize {
self.n()
}
}

View File

@@ -0,0 +1,63 @@
use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft};
pub trait ScalarZnxDftOps<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDft<B>;
fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<B>;
fn svp_prepare(&self, res: &mut ScalarZnxDft<B>, res_col: usize, a: &Scalar, a_col: usize);
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, a_col: usize, b: &VecZnx, b_col: usize);
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::<FFT64>::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDft::<FFT64>::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes)
}
fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes)
}
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, res_col: usize, a: &Scalar, a_col: usize) {
unsafe {
svp::svp_prepare(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t,
a.at_ptr(a_col, 0),
)
}
}
fn svp_apply_dft(
&self,
res: &mut VecZnxDft<FFT64>,
res_col: usize,
a: &ScalarZnxDft<FFT64>,
a_col: usize,
b: &VecZnx,
b_col: usize,
) {
unsafe {
svp::svp_apply_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
a.at_ptr(a_col, 0) as *const svp_ppol_t,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
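Tying the two traits together, a hedged sketch of a prepare-then-apply round trip with explicit column indices (it assumes `module: Module<FFT64>`, a filled `s: Scalar`, an input `a: VecZnx`, and an output `res: VecZnxDft<FFT64>` of matching degree allocated elsewhere):

// One DFT slot per column of s.
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft(s.cols());

// Column 0 of s_dft <- DFT(column 0 of s).
module.svp_prepare(&mut s_dft, 0, &s, 0);

// Column 0 of res <- s_dft[0] * DFT(column 0 of a), applied limb-wise.
module.svp_apply_dft(&mut res, 0, &s_dft, 0, &a, 0);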

View File

@@ -26,7 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
VecZnxDft {
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
_marker: PhantomData,
}

View File

@@ -47,47 +47,47 @@ pub trait VecZnxOps {
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
res_col: usize,
a: &VecZnx,
col_a: usize,
a_col: usize,
tmp_bytes: &mut [u8],
);
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]);
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]);
/// Adds the selected column of `a` to the selected column of `b` and writes the result to the selected column of `res`.
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize);
/// Adds the selected column of `a` to the selected column of `res` in place.
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result to the selected column of `res`.
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize);
/// Subtracts the selected column of `a` from the selected column of `res`.
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`.
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Negates the selected column of `a` and stores the result in the selected column of `res`.
fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Negates the selected column of `a`.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize);
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize);
/// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`.
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize);
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`.
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize);
/// Splits the selected column of `a` into subrings and copies them into the selected column of `res`.
///
@@ -95,7 +95,7 @@ pub trait VecZnxOps {
///
/// This method requires that all [VecZnx] in `res` have the same ring degree
/// and that res[0].n() * res.len() <= a.n()
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx);
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx);
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
///
@@ -103,7 +103,7 @@ pub trait VecZnxOps {
///
/// This method requires that all [VecZnx] in `a` have the same ring degree
/// and that a[0].n() * a.len() <= res.n()
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize);
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize);
}
impl<B: Backend> VecZnxOps for Module<B> {
@@ -131,9 +131,9 @@ impl<B: Backend> VecZnxOps for Module<B> {
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
res_col: usize,
a: &VecZnx,
col_a: usize,
a_col: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
@@ -147,10 +147,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
log_base2k as u64,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
@@ -158,22 +158,22 @@ impl<B: Backend> VecZnxOps for Module<B> {
}
}
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) {
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_normalize(
self,
log_base2k,
&mut *a_ptr,
col_a,
a_col,
&*a_ptr,
col_a,
a_col,
tmp_bytes,
);
}
}
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
@@ -184,27 +184,27 @@ impl<B: Backend> VecZnxOps for Module<B> {
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
}
}
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
@@ -215,34 +215,34 @@ impl<B: Backend> VecZnxOps for Module<B> {
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
}
}
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
}
}
fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
@@ -251,24 +251,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) {
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a);
Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col);
}
}
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
@@ -278,24 +278,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
vec_znx::vec_znx_rotate(
self.ptr,
k,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
}
}
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
@@ -305,24 +305,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
}
}
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) {
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), res[0].n());
debug_assert!(
@@ -339,16 +339,16 @@ impl<B: Backend> VecZnxOps for Module<B> {
res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
switch_degree(bi, col_res, a, col_a);
self.vec_znx_rotate(-1, buf, 0, a, col_a);
switch_degree(bi, res_col, a, a_col);
self.vec_znx_rotate(-1, buf, 0, a, a_col);
} else {
switch_degree(bi, col_res, buf, col_a);
self.vec_znx_rotate_inplace(-1, buf, col_a);
switch_degree(bi, res_col, buf, a_col);
self.vec_znx_rotate_inplace(-1, buf, a_col);
}
})
}
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize) {
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize) {
let (n_in, n_out) = (res.n(), a[0].n());
debug_assert!(
@@ -364,10 +364,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
});
a.iter().enumerate().for_each(|(_, ai)| {
switch_degree(res, col_res, ai, col_a);
self.vec_znx_rotate_inplace(-1, res, col_res);
switch_degree(res, res_col, ai, a_col);
self.vec_znx_rotate_inplace(-1, res, res_col);
});
self.vec_znx_rotate_inplace(a.len() as i64, res, col_res);
self.vec_znx_rotate_inplace(a.len() as i64, res, res_col);
}
}
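With the rename, every column index now follows the operand it selects (`res_col` after `res`, `a_col` after `a`, and so on). A small illustrative call, assuming `res`, `a`, and `b` are VecZnx values with enough columns:

// res[0] <- a[1] + b[0]
module.vec_znx_add(&mut res, 0, &a, 1, &b, 0);
// res[0] <- res[0] + a[1]
module.vec_znx_add_inplace(&mut res, 0, &a, 1);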