diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index f66a4d1..1da44e9 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned, }; use itertools::izip; @@ -8,44 +8,48 @@ use sampling::source::Source; fn main() { let n: usize = 16; let log_base2k: usize = 18; - let cols: usize = 3; + let limbs: usize = 3; let msg_cols: usize = 2; let log_scale: usize = msg_cols * log_base2k - 5; - let module: Module = Module::new(n, BACKEND::FFT64); + let module: Module = Module::::new(n); let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, cols); + let mut res: VecZnx = module.new_vec_znx(1, limbs); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: SvpPPol = module.new_svp_ppol(); + let mut s_ppol: SvpPPol = module.new_svp_ppol(); // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, cols); - module.fill_uniform(log_base2k, &mut a, cols, &mut source); + let mut a: VecZnx = module.new_vec_znx(1, limbs); + module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + + // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs()); // Applies buf_dft <- s * a module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); // Alias scratch space - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); + let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); // buf_big <- IDFT(buf_dft) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + println!("{:?}", buf_big.raw()); + let mut m: VecZnx = module.new_vec_znx(1, msg_cols); let mut want: Vec = vec![0; n]; @@ -59,13 +63,17 @@ fn main() { // buf_big <- m - buf_big module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); + println!("{:?}", buf_big.raw()); + // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, cols); + let mut b: VecZnx = module.new_vec_znx(1, limbs); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); + b.print(n); module.add_normal( log_base2k, &mut b, - log_base2k * cols, + 0, + log_base2k * limbs, &mut source, 3.2, 19.0, @@ -80,14 +88,18 @@ fn main() { // buf_big <- a * s + b module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + println!("raw: {:?}", &buf_big.raw()); + // res <- normalize(buf_big) module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); + + // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); - let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64; + let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index a69c857..4e8b97e 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,5 +1,5 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned, }; @@ -7,50 +7,51 @@ fn main() { let log_n: i32 = 5; let n: usize = 1 << log_n; - let module: Module = Module::new(n, BACKEND::FFT64); + let module: Module = Module::::new(n); let log_base2k: usize = 15; - let cols: usize = 5; - let log_k: usize = log_base2k * cols - 5; + let limbs_vec: usize = 5; + let log_k: usize = log_base2k * limbs_vec - 5; - let rows: usize = cols; - let cols: usize = cols + 1; + let rows_mat: usize = limbs_vec; + let limbs_mat: usize = limbs_vec + 1; // Maximum size of the byte scratch needed - let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols); + let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat) + | module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat); let mut buf: Vec = alloc_aligned(tmp_bytes); let mut a_values: Vec = vec![i64::default(); n]; a_values[1] = (1 << log_base2k) + 1; - let mut a: VecZnx = module.new_vec_znx(1, rows); + let mut a: VecZnx = module.new_vec_znx(1, limbs_vec); a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32); a.normalize(log_base2k, &mut buf); - a.print(0, a.cols(), n); + a.print(n); println!(); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols); + let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows_mat, 1, limbs_mat); - (0..a.cols()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(1, cols); - tmp.at_mut(row_i)[1] = 1 as i64; + (0..a.limbs()).for_each(|row_i| { + let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); + tmp.at_limb_mut(row_i)[1] = 1 as i64; module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); }); - let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); + let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); - let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); + let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); - let mut res: VecZnx = module.new_vec_znx(1, rows); + let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; res.decode_vec_i64(0, log_base2k, log_k, &mut values_res); - res.print(0, res.cols(), n); + res.print(n); module.free(); diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index c8c08e9..d4085cb 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -9,129 +9,130 @@ pub trait Encoding { /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `data`: data to encode on the receiver. /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); /// decode a vector of i64 from the receiver. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two logarithm of the scaling of the data. /// * `data`: data to decode from the receiver. - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); /// decode a vector of Float from the receiver. /// /// # Arguments - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `data`: data to decode from the receiver. - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]); + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); /// encodes a single i64 on the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient on which to encode the data. /// * `data`: data to encode on the receiver. /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); /// decode a single of i64 from the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient to decode. /// * `data`: data to decode from the receiver. - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; } impl Encoding for VecZnx { - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max) + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - decode_vec_i64(self, poly_idx, log_base2k, log_k, data) + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + decode_vec_i64(self, col_i, log_base2k, log_k, data) } - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - decode_vec_float(self, poly_idx, log_base2k, data) + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]) { + decode_vec_float(self, col_i, log_base2k, data) } - fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max) + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) } - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { - decode_coeff_i64(self, poly_idx, log_base2k, log_k, i) + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { + decode_coeff_i64(self, col_i, log_base2k, log_k, i) } } -fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + limbs <= a.limbs(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", + limbs, + a.limbs() ); - assert!(poly_idx < a.size); + assert!(col_i < a.cols()); assert!(data.len() <= a.n()) } let data_len: usize = data.len(); let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|i| unsafe { - znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i)); + // Zeroes coefficients of the i-th column + (0..a.limbs()).for_each(|i| unsafe { + znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); }); // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs) .rev() .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } // Case where self.prec % self.k != 0. if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|i| { - a.at_poly_mut(poly_idx, i)[..data_len] + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|i| { + a.at_poly_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) } } -fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( @@ -140,26 +141,26 @@ fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - data.copy_from_slice(a.at_poly(poly_idx, 0)); + data.copy_from_slice(a.at_poly(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); - (1..cols).for_each(|i| { - if i == cols - 1 && rem != log_base2k { + (1..limbs).for_each(|i| { + if i == limbs - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } }) } -fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - let cols: usize = a.cols(); +fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { + let limbs: usize = a.limbs(); #[cfg(debug_assertions)] { assert!( @@ -168,23 +169,23 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - let prec: u32 = (log_base2k * cols) as u32; + let prec: u32 = (log_base2k * limbs) as u32; // 2^{log_base2k} let base = Float::with_val(prec, (1 << log_base2k) as f64); // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..cols).for_each(|i| { + (0..limbs).for_each(|i| { if i == 0 { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -192,61 +193,61 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ }); } -fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!(i < a.n()); assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + limbs <= a.limbs(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", + limbs, + a.limbs() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0); + (0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[i] = value; + a.at_poly_mut(col_i, limbs - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs) .rev() .enumerate() .for_each(|(j, j_rev)| { - a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask; + a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } // Case where prec % k != 0. if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|j| { - a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|j| { + a.at_poly_mut(col_i, j)[i] <<= log_k_rem; }) } } -fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); - assert!(poly_idx < a.size()) + assert!(col_i < a.cols()) } let cols: usize = (log_k + log_base2k - 1) / log_base2k; let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = log_base2k - (log_k % log_base2k); - let slice_size: usize = a.n() * a.size(); + let slice_size: usize = a.n() * a.limbs(); (1..cols).for_each(|i| { let x = data[i * slice_size]; if i == cols - 1 && rem != log_base2k { @@ -275,13 +276,13 @@ mod tests { let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut have: Vec = vec![i64::default(); n]; have.iter_mut() .for_each(|x| *x = (source.next_i64() << 56) >> 56); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 10); let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }); } @@ -296,12 +297,12 @@ mod tests { let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut have: Vec = vec![i64::default(); n]; have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 64); let mut want = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }) } diff --git a/base2k/src/ffi/vec_znx_big.rs b/base2k/src/ffi/vec_znx_big.rs index e1222c3..8c06e90 100644 --- a/base2k/src/ffi/vec_znx_big.rs +++ b/base2k/src/ffi/vec_znx_big.rs @@ -8,17 +8,17 @@ pub struct vec_znx_big_t { pub type VEC_ZNX_BIG = vec_znx_big_t; unsafe extern "C" { - pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; + pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; } unsafe extern "C" { - pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; + pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; } unsafe extern "C" { - pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); + pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); } unsafe extern "C" { - pub fn vec_znx_big_add( + pub unsafe fn vec_znx_big_add( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -29,7 +29,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small( + pub unsafe fn vec_znx_big_add_small( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -41,7 +41,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small2( + pub unsafe fn vec_znx_big_add_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -54,7 +54,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub( + pub unsafe fn vec_znx_big_sub( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -65,7 +65,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small_b( + pub unsafe fn vec_znx_big_sub_small_b( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -77,7 +77,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small_a( + pub unsafe fn vec_znx_big_sub_small_a( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -89,7 +89,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small2( + pub unsafe fn vec_znx_big_sub_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -101,8 +101,13 @@ unsafe extern "C" { b_sl: u64, ); } + unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k( + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_normalize_base2k( module: *const MODULE, log2_base2k: u64, res: *mut i64, @@ -113,34 +118,9 @@ unsafe extern "C" { tmp_space: *mut u8, ); } -unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; -} unsafe extern "C" { - pub fn vec_znx_big_automorphism( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_rotate( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k( + pub unsafe fn vec_znx_big_range_normalize_base2k( module: *const MODULE, log2_base2k: u64, res: *mut i64, @@ -153,6 +133,29 @@ unsafe extern "C" { tmp_space: *mut u8, ); } + unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_automorphism( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_rotate( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); } diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index ba799d7..764a7fe 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,28 +1,4 @@ -#[derive(Copy, Clone)] -#[repr(C)] -pub struct LAYOUT{ - /// Ring degree. - n: usize, - /// Number of logical rows in the layout. - rows: usize, - /// Number of polynomials per row. - cols: usize, - /// Number of limbs per polynomial. - size: usize, - /// Whether limbs are interleaved inside a row. - /// - /// For example, for (rows, cols, size) = (2, 2, 3): - /// - /// - `true`: layout is ((a0, b0, a1, b1, a2, b2), (c0, d0, c1, d1, c2, d2)) - /// - `false`: layout is ((a0, a1, a2, b0, b1, b2), (c0, c1, c2, d0, d1, d2)) - interleaved : bool, -} - pub trait Infos { - - /// Returns the full layout. - fn layout(&self) -> LAYOUT; - /// Returns the ring degree of the polynomials. fn n(&self) -> usize; @@ -36,8 +12,8 @@ pub trait Infos { fn cols(&self) -> usize; /// Returns the number of limbs per polynomial. - fn size(&self) -> usize; + fn limbs(&self) -> usize; - /// Whether limbs are interleaved across rows. - fn interleaved(&self) -> bool; + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize; } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index ec0d2b7..5144afd 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -27,7 +27,7 @@ pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; -pub fn is_aligned_custom(ptr: *const T, align: usize) -> bool { +fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } @@ -54,13 +54,10 @@ pub fn cast_mut(data: &[T]) -> &mut [V] { unsafe { std::slice::from_raw_parts_mut(ptr, len) } } -use std::alloc::{Layout, alloc}; -use std::ptr; - /// Allocates a block of bytes with a custom alignement. /// Alignement must be a power of two and size a multiple of the alignement. /// Allocated memory is initialized to zero. -pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { +fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { assert!( align.is_power_of_two(), "Alignment must be a power of two but is {}", @@ -74,8 +71,8 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); unsafe { - let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment"); - let ptr: *mut u8 = alloc(layout); + let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); + let ptr: *mut u8 = std::alloc::alloc(layout); if ptr.is_null() { panic!("Memory allocation failed"); } @@ -86,18 +83,11 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); // Init allocated memory to zero - ptr::write_bytes(ptr, 0, size); + std::ptr::write_bytes(ptr, 0, size); Vec::from_raw_parts(ptr, size, size) } } -/// Allocates a block of bytes aligned with [DEFAULTALIGN]. -/// Size must be amultiple of [DEFAULTALIGN]. -/// /// Allocated memory is initialized to zero. -pub fn alloc_aligned_u8(size: usize) -> Vec { - alloc_aligned_custom_u8(size, DEFAULTALIGN) -} - /// Allocates a block of T aligned with [DEFAULTALIGN]. /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 8cbdbca..205cf62 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,5 +1,6 @@ use crate::GALOISGENERATOR; use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; #[derive(Copy, Clone)] #[repr(u8)] @@ -8,37 +9,50 @@ pub enum BACKEND { NTT120, } -pub struct Module { - pub ptr: *mut MODULE, - pub n: usize, - pub backend: BACKEND, +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; } -impl Module { +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + pub n: usize, + _marker: PhantomData, +} + +impl Module { // Instantiates a new module. - pub fn new(n: usize, module_type: BACKEND) -> Self { + pub fn new(n: usize) -> Self { unsafe { - let module_type_u32: u32; - match module_type { - BACKEND::FFT64 => module_type_u32 = 0, - BACKEND::NTT120 => module_type_u32 = 1, - } - let m: *mut module_info_t = new_module_info(n as u64, module_type_u32); + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); if m.is_null() { panic!("Failed to create module."); } Self { ptr: m, n: n, - backend: module_type, + _marker: PhantomData, } } } - pub fn backend(&self) -> BACKEND { - self.backend - } - pub fn n(&self) -> usize { self.n } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 064c1e2..db9a79b 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,16 +1,17 @@ -use crate::{Infos, Module, VecZnx}; +use crate::{Backend, Infos, Module, VecZnx}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { - /// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source); + /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. fn add_dist_f64>( &self, log_base2k: usize, a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -18,24 +19,35 @@ pub trait Sampling { ); /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64); + fn add_normal( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ); } -impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) { +impl Sampling for Module { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - let size: usize = a.n() * cols; - a.raw_mut()[..size] - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + (0..limbs).for_each(|j| { + a.at_poly_mut(col_i, j) + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + }) } fn add_dist_f64>( &self, log_base2k: usize, a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -50,28 +62,42 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << log_base2k_rem; - }); + a.at_poly_mut(col_i, a.limbs() - 1) + .iter_mut() + .for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); } else { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); + a.at_poly_mut(col_i, a.limbs() - 1) + .iter_mut() + .for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); } } - fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + fn add_normal( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { self.add_dist_f64( log_base2k, a, + col_i, log_k, source, Normal::new(0.0, sigma).unwrap(), diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index bc37f86..e293668 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -1,6 +1,8 @@ +use std::marker::PhantomData; + use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, assert_alignement}; use crate::{Infos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; @@ -14,7 +16,7 @@ pub struct Scalar { pub ptr: *mut i64, } -impl Module { +impl Module { pub fn new_scalar(&self) -> Scalar { Scalar::new(self.n()) } @@ -117,9 +119,8 @@ impl Scalar { pub fn as_vec_znx(&self) -> VecZnx { VecZnx { n: self.n, - size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR cols: 1, - layout: LAYOUT::COL(1, 1), + limbs: 1, data: Vec::new(), ptr: self.ptr, } @@ -132,7 +133,7 @@ pub trait ScalarOps { fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar; fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar; } -impl ScalarOps for Module { +impl ScalarOps for Module { fn bytes_of_scalar(&self) -> usize { Scalar::bytes_of(self.n()) } @@ -147,17 +148,17 @@ impl ScalarOps for Module { } } -pub struct SvpPPol { +pub struct SvpPPol { pub n: usize, pub data: Vec, pub ptr: *mut u8, - pub backend: BACKEND, + _marker: PhantomData, } /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl SvpPPol { - pub fn new(module: &Module) -> Self { +impl SvpPPol { + pub fn new(module: &Module) -> Self { module.new_svp_ppol() } @@ -166,11 +167,11 @@ impl SvpPPol { self.n } - pub fn bytes_of(module: &Module) -> usize { + pub fn bytes_of(module: &Module) -> usize { module.bytes_of_svp_ppol() } - pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol { + pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(bytes.as_ptr()); @@ -181,12 +182,12 @@ impl SvpPPol { n: module.n(), data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), - backend: module.backend(), + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol { + pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -196,7 +197,7 @@ impl SvpPPol { n: module.n(), data: Vec::new(), ptr: tmp_bytes.as_mut_ptr(), - backend: module.backend(), + _marker: PhantomData, } } @@ -206,9 +207,9 @@ impl SvpPPol { } } -pub trait SvpPPolOps { +pub trait SvpPPolOps { /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> SvpPPol; + fn new_svp_ppol(&self) -> SvpPPol; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. @@ -217,30 +218,30 @@ pub trait SvpPPolOps { /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); } -impl SvpPPolOps for Module { - fn new_svp_ppol(&self) -> SvpPPol { +impl SvpPPolOps for Module { + fn new_svp_ppol(&self) -> SvpPPol { let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); let ptr: *mut u8 = data.as_mut_ptr(); - SvpPPol { + SvpPPol:: { data: data, ptr: ptr, n: self.n(), - backend: self.backend(), + _marker: PhantomData, } } @@ -248,19 +249,19 @@ impl SvpPPolOps for Module { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { SvpPPol::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { SvpPPol::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { unsafe { svp::svp_apply_dft( self.ptr, diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 71a315e..aff1ce9 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,4 +1,4 @@ -use crate::LAYOUT; +use crate::Backend; use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; @@ -22,11 +22,11 @@ pub struct VecZnx { /// Polynomial degree. pub n: usize, - /// Number of limbs - pub size: usize, + /// The number of polynomials + pub cols: usize, - /// Layout - pub layout: LAYOUT, + /// The number of limbs per polynomial (a.k.a small polynomials). + pub limbs: usize, /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. pub data: Vec, @@ -35,58 +35,60 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, layout: LAYOUT, size: usize) -> usize { - n * layout.size() * size * 8 +pub fn bytes_of_vec_znx(n: usize, cols: usize, limbs: usize) -> usize { + n * cols * limbs * size_of::() } impl VecZnx { /// Returns a new struct implementing [VecZnx] with the provided data as backing array. /// - /// The struct will take ownership of buf[..[VecZnx::bytes_of]] + /// The struct will take ownership of buf[..[Self::bytes_of]] /// /// User must ensure that data is properly alligned and that - /// the size of data is equal to [VecZnx::bytes_of]. - pub fn from_bytes(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { + /// the limbs of data is equal to [Self::bytes_of]. + pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, layout, size)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); assert_alignement(bytes.as_ptr()); } unsafe { let bytes_i64: &mut [i64] = cast_mut::(bytes); let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - VecZnx { + Self { n: n, - size: size, - layout: layout, + cols: cols, + limbs: limbs, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - pub fn from_bytes_borrow(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(n, layout, size)); + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); assert_alignement(bytes.as_ptr()); } - VecZnx { + Self { n: n, - size: size, - layout: layout, + cols: cols, + limbs: limbs, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } } - pub fn bytes_of(n: usize, layout: LAYOUT, size: usize) -> usize { - bytes_of_vec_znx(n, layout, size) + pub fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { + bytes_of_vec_znx(n, cols, limbs) } - pub fn copy_from(&mut self, a: &VecZnx) { + pub fn copy_from(&mut self, a: &Self) { copy_vec_znx_from(self, a); } @@ -94,15 +96,15 @@ impl VecZnx { self.data.len() == 0 } - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. + /// Total limbs is [Self::n()] * [Self::poly_count()]. pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.size) } + unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.poly_count()) } } /// Returns a reference to backend slice of the receiver. - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. + /// Total size is [Self::n()] * [Self::poly_count()]. pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.size) } + unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.poly_count()) } } /// Returns a non-mutable pointer to the backedn slice of the receiver. @@ -115,76 +117,55 @@ impl VecZnx { self.ptr } - /// Returns a non-mutable pointer starting a the j-th column. - pub fn at_ptr(&self, i: usize) -> *const i64 { + /// Returns a non-mutable pointer starting a the (i, j)-th small poly. + pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { #[cfg(debug_assertions)] { - assert!(i < self.size); + assert!(i < self.cols()); + assert!(j < self.limbs()); } - let offset: usize = self.n * self.size * i; + let offset: usize = self.n * (j * self.cols() + i); self.ptr.wrapping_add(offset) } - /// Returns non-mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. - pub fn at(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i), self.n * self.size) } + /// Returns a non-mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb(&self, i: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } } - /// Returns a non-mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.size); - } - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns non-mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. + /// Returns a non-mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_poly_ptr(i, j), self.n) } + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } } - /// Returns a mutable pointer starting a the j-th column. - pub fn at_mut_ptr(&self, i: usize) -> *mut i64 { + /// Returns a mutable pointer starting a the (i, j)-th small poly. + pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { #[cfg(debug_assertions)] { - assert!(i < self.size); + assert!(i < self.cols()); + assert!(j < self.limbs()); } - let offset: usize = self.n * self.size * i; + + let offset: usize = self.n * (j * self.cols() + i); self.ptr.wrapping_add(offset) } - /// Returns mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. - pub fn at_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n * self.size) } + /// Returns a mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } } - /// Returns a mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.size); - } - - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. + /// Returns a mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - let ptr: *mut i64 = self.at_poly_mut_ptr(i, j); - unsafe { std::slice::from_raw_parts_mut(ptr, self.n) } + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } } pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.size * self.size) as u64, self.ptr) } + unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } } pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { @@ -195,48 +176,47 @@ impl VecZnx { rsh(log_base2k, self, k, carry) } - pub fn switch_degree(&self, a: &mut VecZnx) { + pub fn switch_degree(&self, a: &mut Self) { switch_degree(a, self) } - pub fn print(&self, poly: usize, size: usize, n: usize) { - (0..size).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) } } impl Infos for VecZnx { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } - fn size(&self) -> usize { - self.size + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ } - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of size of the [VecZnx]. - fn size(&self) -> usize { - self.size - } - - /// Returns the number of rows of the [VecZnx]. fn rows(&self) -> usize { 1 } + + fn cols(&self) -> usize { + self.cols + } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs + } } /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. +/// Panics if the cols do not match. pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { + assert_eq!(b.cols(), a.cols()); let data_a: &[i64] = a.raw(); let data_b: &mut [i64] = b.raw_mut(); let size = min(data_b.len(), data_a.len()); @@ -245,21 +225,20 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { impl VecZnx { /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - pub fn new(n: usize, size: usize, size: usize) -> Self { + pub fn new(n: usize, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); - assert!(size > 0); - assert!(size <= u8::MAX as usize); - assert!(size > 0); + assert!(cols > 0); + assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(n * size * size); + let mut data: Vec = alloc_aligned::(n * cols * limbs); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, - layout: LAYOUT::COL(1, size as u8), - size: size, + cols: cols, + limbs: limbs, data: data, ptr: ptr, } @@ -278,16 +257,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate((self.size() - k / log_base2k) * self.n() * self.size()); + .truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k)); } - self.size -= k / log_base2k; + self.limbs -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_mut(self.size() - 1) + self.at_limb_mut(self.limbs() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -305,31 +284,31 @@ pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { b.zero(); } - let size = min(a.size(), b.size()); + let limbs: usize = min(a.limbs(), b.limbs()); - (0..size).for_each(|i| { + (0..limbs).for_each(|i| { izip!( - a.at(i).iter().step_by(gap_in), - b.at_mut(i).iter_mut().step_by(gap_out) + a.at_limb(i).iter().step_by(gap_in), + b.at_limb_mut(i).iter_mut().step_by(gap_out) ) .for_each(|(x_in, x_out)| *x_out = *x_in); }); } -fn normalize_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +fn normalize_tmp_bytes(n: usize, limbs: usize) -> usize { + n * limbs * std::mem::size_of::() } fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let size: usize = a.size(); + let cols: usize = a.cols(); debug_assert!( - tmp_bytes.len() >= normalize_tmp_bytes(n, size), + tmp_bytes.len() >= normalize_tmp_bytes(n, cols), "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})", tmp_bytes.len(), n, - size, + cols, ); #[cfg(debug_assertions)] { @@ -340,45 +319,45 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.size()).rev().for_each(|i| { + (0..a.limbs()).rev().for_each(|i| { znx::znx_normalize( - (n * size) as u64, + (n * cols) as u64, log_base2k as u64, - a.at_mut_ptr(i), + a.at_mut_ptr(0, i), carry_i64.as_mut_ptr(), - a.at_mut_ptr(i), + a.at_mut_ptr(0, i), carry_i64.as_mut_ptr(), ) }); } } -pub fn rsh_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize { + n * limbs * std::mem::size_of::() } pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let size: usize = a.size(); + let limbs: usize = a.limbs(); #[cfg(debug_assertions)] { assert!( - tmp_bytes.len() >= rsh_tmp_bytes(n, size), + tmp_bytes.len() >= rsh_tmp_bytes(n, limbs), "invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})", tmp_bytes.len() >> 3, n, - size, + limbs, ); assert_alignement(tmp_bytes.as_ptr()); } - let size: usize = a.size(); + let limbs: usize = a.limbs(); let size_steps: usize = k / log_base2k; - a.raw_mut().rotate_right(n * size * size_steps); + a.raw_mut().rotate_right(n * limbs * size_steps); unsafe { - znx::znx_zero_i64_ref((n * size * size_steps) as u64, a.as_mut_ptr()); + znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr()); } let k_rem = k % log_base2k; @@ -387,13 +366,13 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let carry_i64: &mut [i64] = cast_mut(tmp_bytes); unsafe { - znx::znx_zero_i64_ref((n * size) as u64, carry_i64.as_mut_ptr()); + znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr()); } let log_base2k: usize = log_base2k; - (size_steps..size).for_each(|i| { - izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { + (size_steps..limbs).for_each(|i| { + izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k; *ci = get_base_k_carry(*xi, k_rem); *xi = (*xi - *ci) >> k_rem; @@ -412,14 +391,15 @@ pub trait VecZnxOps { /// /// # Arguments /// - /// * `size`: the number of size. - fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx; + /// * `cols`: the number of polynomials. + /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). + fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. - fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize; + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize; + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; /// c <- a + b. fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); @@ -471,17 +451,17 @@ pub trait VecZnxOps { fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); } -impl VecZnxOps for Module { - fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx { - VecZnx::new(self.n(), size, size) +impl VecZnxOps for Module { + fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { + VecZnx::new(self.n(), cols, limbs) } - fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize { - bytes_of_vec_znx(self.n(), size, size) + fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { + bytes_of_vec_znx(self.n(), cols, limbs) } - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * size } + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } } // c <- a + b @@ -497,14 +477,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, c.as_mut_ptr(), - c.size() as u64, - (n * c.size()) as u64, + c.limbs() as u64, + (n * c.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -521,14 +501,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -546,14 +526,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, c.as_mut_ptr(), - c.size() as u64, - (n * c.size()) as u64, + c.limbs() as u64, + (n * c.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -570,14 +550,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -594,14 +574,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -617,11 +597,11 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -636,11 +616,11 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -657,11 +637,11 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -677,11 +657,11 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -697,7 +677,7 @@ impl VecZnxOps for Module { /// /// # Panics /// - /// The method will panic if the argument `a` is greater than `a.size()`. + /// The method will panic if the argument `a` is greater than `a.limbs()`. fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -710,11 +690,11 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ); } } @@ -729,7 +709,7 @@ impl VecZnxOps for Module { /// /// # Panics /// - /// The method will panic if the argument `size` is greater than `self.size()`. + /// The method will panic if the argument `size` is greater than `self.limbs()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -741,11 +721,11 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ); } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 705a5ec..b19f126 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,24 +1,26 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use std::marker::PhantomData; -pub struct VecZnxBig { +pub struct VecZnxBig { pub data: Vec, pub ptr: *mut u8, pub n: usize, - pub size: usize, pub cols: usize, - pub layout: LAYOUT, - pub backend: BACKEND, + pub limbs: usize, + pub _marker: PhantomData, } -impl VecZnxBig { +impl VecZnxBig { /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -26,91 +28,84 @@ impl VecZnxBig { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft { + pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { + VecZnxDft:: { data: Vec::new(), ptr: self.ptr, n: self.n, - size: self.size, - layout: LAYOUT::COL, cols: self.cols, - backend: self.backend, + limbs: self.limbs, + _marker: self._marker, } } - pub fn backend(&self) -> BACKEND { - self.backend + /// Returns a non-mutable reference to the entire contiguous array of the [VecZnxDft]. + pub fn raw(&self) -> &[i64] { + let ptr: *const i64 = self.ptr as *const i64; + unsafe { &std::slice::from_raw_parts(ptr, self.n() * self.poly_count()) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + let raw: &[i64] = self.raw(); + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &raw[i * self.n() * self.cols()..i * self.n() * self.cols()+n])) } } -impl Infos for VecZnxBig { - /// Returns the base 2 logarithm of the [VecZnx] degree. +impl Infos for VecZnxBig { fn log_n(&self) -> usize { (usize::BITS - (self.n - 1).leading_zeros()) as _ } - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } - fn size(&self) -> usize { - self.size - } - - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols } - /// Returns the number of rows of the [VecZnx]. fn rows(&self) -> usize { 1 } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs + } } -pub trait VecZnxBigOps { +pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -118,12 +113,13 @@ pub trait VecZnxBigOps { /// /// # Arguments /// - /// * `cols`: the number of cols of the [VecZnxBig]. + /// * `cols`: the number of polynomials.. + /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -131,33 +127,44 @@ pub trait VecZnxBigOps { /// /// # Arguments /// - /// * `cols`: the number of cols of the [VecZnxBig]. + /// * `cols`: the number of polynomials.. + /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. - fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize; - /// b <- b - a - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] + /// + /// # Behavior + /// + /// [VecZnxBig] (3 cols and 4 limbs) + /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] + /// - + /// [VecZnx] (2 cols and 3 limbs) + /// [d0, e0] [d1, e1] [d2, e2] + /// = + /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); /// c <- b - a - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// c <- b + a - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// b <- b + a - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; /// b <- normalize(a) - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); + fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; @@ -165,100 +172,111 @@ pub trait VecZnxBigOps { &self, log_base2k: usize, res: &mut VecZnx, - a: &VecZnxBig, + a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, a_range_step: usize, tmp_bytes: &mut [u8], ); - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); + fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); } -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(size, cols)); +impl VecZnxBigOps for Module { + fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } + let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxBig { + VecZnxBig:: { data: data, ptr: ptr, n: self.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: self.backend(), + limbs: limbs, + _marker: PhantomData, } } - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, size, cols, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, cols, limbs, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes) + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes) } - fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size } + fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, limbs as u64) as usize * cols } } - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + /// [VecZnxBig] (3 cols and 4 limbs) + /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] + /// - + /// [VecZnx] (2 cols and 3 limbs) + /// [d0, e0] [d1, e1] [d2, e2] + /// = + /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, ) } } - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, c.ptr as *mut vec_znx_big_t, - c.cols() as u64, + c.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, ) } } - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, c.ptr as *mut vec_znx_big_t, - c.cols() as u64, + c.poly_count() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, ) } } - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, ) } @@ -268,12 +286,12 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { + fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_normalize_tmp_bytes(self), + tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", tmp_bytes.len(), - ::vec_znx_big_normalize_tmp_bytes(self) + Self::vec_znx_big_normalize_tmp_bytes(self) ); #[cfg(debug_assertions)] { @@ -284,10 +302,10 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, b.as_mut_ptr(), - b.cols() as u64, + b.limbs() as u64, b.n() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -301,17 +319,17 @@ impl VecZnxBigOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - a: &VecZnxBig, + a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, a_range_step: usize, tmp_bytes: &mut [u8], ) { debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_range_normalize_base2k_tmp_bytes(self), + tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", tmp_bytes.len(), - ::vec_znx_big_range_normalize_base2k_tmp_bytes(self) + Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) ); #[cfg(debug_assertions)] { @@ -322,7 +340,7 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, res.as_mut_ptr(), - res.cols() as u64, + res.limbs() as u64, res.n() as u64, a.ptr as *mut vec_znx_big_t, a_range_begin as u64, @@ -333,28 +351,28 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, ); } } - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { + fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, ); } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b512fd8..ec4067f 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,136 +1,139 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnxBig, assert_alignement}; use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; +use std::marker::PhantomData; -pub struct VecZnxDft { +pub struct VecZnxDft { pub data: Vec, pub ptr: *mut u8, pub n: usize, - pub size: usize, - pub layout: LAYOUT, pub cols: usize, - pub backend: BACKEND, + pub limbs: usize, + pub _marker: PhantomData, } -impl VecZnxDft { +impl VecZnxDft { + pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + Self { + data: data, + ptr: ptr, + n: module.n(), + limbs: limbs, + cols: cols, + _marker: PhantomData, + } + } /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); assert_alignement(bytes.as_ptr()) } unsafe { - VecZnxDft { + Self { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); assert_alignement(bytes.as_ptr()); } - VecZnxDft { + Self { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } /// Cast a [VecZnxDft] into a [VecZnxBig]. /// The returned [VecZnxBig] shares the backing array /// with the original [VecZnxDft]. - pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig { + pub fn as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { data: Vec::new(), ptr: self.ptr, n: self.n, - layout: LAYOUT::COL, - size: self.size, cols: self.cols, - backend: self.backend, + limbs: self.limbs, + _marker: PhantomData, } } - pub fn backend(&self) -> BACKEND { - self.backend + pub fn raw(&self) -> &[f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { &std::slice::from_raw_parts(ptr, size) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + pub fn at(&self, col_i: usize) -> &[f64] { + &self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()] } - pub fn at(&self, module: &Module, col_i: usize) -> &[T] { - &self.raw::(module)[col_i * module.n()..(col_i + 1) * module.n()] + pub fn raw_mut(&mut self) -> &mut [f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } } - /// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw_mut(&self, module: &Module) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } - } - - pub fn at_mut(&self, module: &Module, col_i: usize) -> &mut [T] { - &mut self.raw_mut::(module)[col_i * module.n()..(col_i + 1) * module.n()] + pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] { + let n: usize = self.n(); + let limbs:usize = self.limbs(); + &mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs] } } -impl Infos for VecZnxDft { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. +impl Infos for VecZnxDft { fn n(&self) -> usize { self.n } - fn layout(&self) -> LAYOUT { - self.layout + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + fn rows(&self) -> usize { + 1 } - /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols } - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs } } -pub trait VecZnxDftOps { +pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -143,7 +146,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -156,7 +159,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -167,61 +170,51 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); + fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); + fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; } -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_dft(size, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxDft { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: self.backend(), - } +impl VecZnxDftOps for Module { + fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft { + VecZnxDft::::new(&self, cols, limbs) } - fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, size, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes) } - fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size } + fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(self.ptr, limbs as u64) as usize * cols } } - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { unsafe { vec_znx_dft::vec_znx_idft_tmp_a( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, ) } } @@ -234,21 +227,21 @@ impl VecZnxDftOps for Module { /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { unsafe { vec_znx_dft::vec_znx_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, a.as_ptr(), - a.cols() as u64, - a.n() as u64, + a.limbs() as u64, + (a.n() * a.cols()) as u64, ) } } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert!( @@ -263,29 +256,29 @@ impl VecZnxDftOps for Module { vec_znx_dft::vec_znx_idft( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { + fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { unsafe { vec_znx_dft::vec_znx_dft_automorphism( self.ptr, k, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, [0u8; 0].as_mut_ptr(), ); } } - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { + fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert!( @@ -301,9 +294,9 @@ impl VecZnxDftOps for Module { self.ptr, k, a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, tmp_bytes.as_mut_ptr(), ); } @@ -321,41 +314,47 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; use itertools::izip; use sampling::source::{Source, new_seed}; #[test] fn test_automorphism_dft() { - let module: Module = Module::new(128, BACKEND::FFT64); + let module: Module = Module::::new(128); - let cols: usize = 2; + let limbs: usize = 2; let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); + let mut a: VecZnx = module.new_vec_znx(1, limbs); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); let mut source: Source = Source::new(new_seed()); - module.fill_uniform(log_base2k, &mut a, cols, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); let p: i64 = -5; + + // a_dft <- DFT(a) module.vec_znx_dft(&mut a_dft, &a); + + // a_dft <- AUTO(a_dft) module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); + println!("123"); + // a <- AUTO(a) module.vec_znx_automorphism_inplace(p, &mut a); // b_dft <- DFT(AUTO(a)) module.vec_znx_dft(&mut b_dft, &a); - let a_f64: &[f64] = a_dft.raw(&module); - let b_f64: &[f64] = b_dft.raw(&module); + let a_f64: &[f64] = a_dft.raw(); + let b_f64: &[f64] = b_dft.raw(); izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); }); diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index b04232d..05dd027 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,7 +1,8 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. @@ -9,28 +10,23 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_ /// /// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. /// See the trait [VmpPMatOps] for additional information. -pub struct VmpPMat { +pub struct VmpPMat { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. ptr: *mut u8, - /// The size of the decomposition basis (i.e. nb. [VecZnxDft]). - rows: usize, - /// The size of each [VecZnxDft]. - cols: usize, - /// The ring degree of each [VecZnxDft]. + /// The ring degree of each polynomial. n: usize, - /// 1nd dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension). - /// A value greater than one enables to compute a sum of [VecZnx] x [VmpPMat]. - /// 2st dim: the number of stacked [VecZnxDft] (col-dimension). - /// A value greater than one enables to compute multiple [VecZnx] x [VmpPMat] in parallel. - layout: LAYOUT, - /// The backend fft or ntt. - backend: BACKEND, + /// Number of rows + rows: usize, + /// Number of cols + cols: usize, + /// The number of small polynomials + limbs: usize, + _marker: PhantomData, } -impl Infos for VmpPMat { - /// Returns the ring dimension of the [VmpPMat]. +impl Infos for VmpPMat { fn n(&self) -> usize { self.n } @@ -39,29 +35,39 @@ impl Infos for VmpPMat { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } - fn size(&self) -> usize { - self.size - } - - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat] fn rows(&self) -> usize { self.rows } - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of cols - /// of each [VecZnxDft]. - /// This method is equivalent to [Self::cols]. fn cols(&self) -> usize { self.cols } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.rows * self.cols * self.limbs + } } -impl VmpPMat { +impl VmpPMat { + + fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat { + let mut data: Vec = alloc_aligned::(module.bytes_of_vmp_pmat(rows, cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + VmpPMat:: { + data: data, + ptr: ptr, + n: module.n(), + rows: rows, + cols: cols, + limbs: limbs, + _marker: PhantomData, + } + } + pub fn as_ptr(&self) -> *const u8 { self.ptr } @@ -74,41 +80,31 @@ impl VmpPMat { self.data.len() == 0 } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. - pub fn raw(&self) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + /// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat]. + pub fn raw(&self) -> &[f64] { + let ptr: *const f64 = self.ptr as *const f64; + let size: usize = self.n() * self.poly_count(); + unsafe { &std::slice::from_raw_parts(ptr, size) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. - pub fn raw_mut(&self) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } + /// Returns a mutable reference of to the entire contiguous array of the [VmpPMat]. + pub fn raw_mut(&self) -> &mut [f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } } /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. /// /// # Arguments /// /// * `row`: row index (i). /// * `col`: col index (j). - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = alloc_aligned(self.n); + pub fn at(&self, row: usize, col: usize) -> Vec { + let mut res: Vec = alloc_aligned(self.n); if self.n < 8 { - res.copy_from_slice( - &self.raw::()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)], - ); + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]); } else { (0..self.n >> 3).for_each(|blk| { res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); @@ -118,43 +114,37 @@ impl VmpPMat { res } - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[T] { + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); - let ncols: usize = self.cols(); - if col == (ncols - 1) && (ncols & 1 == 1) { - &self.raw::()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] + let nsize: usize = self.limbs(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] } else { - &self.raw::()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] + &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] } } - - fn backend(&self) -> BACKEND { - self.backend - } } /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize; +pub trait VmpPMatOps { + fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// /// # Arguments /// /// * `rows`: number of rows (number of [VecZnxDft]). - /// * `cols`: number of cols (number of cols of each [VecZnxDft]). - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat; + /// * `size`: number of size (number of size of each [VecZnxDft]). + fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// /// # Arguments /// /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - /// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; /// Prepares a [VmpPMat] from a contiguous array of [i64]. /// The helper struct [Matrix3D] can be used to contruct and populate @@ -165,18 +155,7 @@ pub trait VmpPMatOps { /// * `b`: [VmpPMat] on which the values are encoded. /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); - - /// Prepares a [VmpPMat] from a vector of [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]); + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// @@ -188,7 +167,7 @@ pub trait VmpPMatOps { /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. /// @@ -197,7 +176,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. /// @@ -208,7 +187,7 @@ pub trait VmpPMatOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); + fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. /// @@ -217,17 +196,17 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// /// # Arguments /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnx]. + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnx]. /// * `rows`: number of rows of the input [VmpPMat]. - /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the input [VmpPMat]. + fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// @@ -235,8 +214,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -253,7 +232,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. /// @@ -261,8 +240,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -279,17 +258,17 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// /// # Arguments /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnxDft]. + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnxDft]. /// * `rows`: number of rows of the input [VmpPMat]. - /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the input [VmpPMat]. + fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -298,8 +277,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -316,7 +295,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -325,8 +304,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -343,7 +322,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -352,8 +331,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -369,38 +348,29 @@ pub trait VmpPMatOps { /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); } -impl VmpPMatOps for Module { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size } +impl VmpPMatOps for Module { + + fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat { + VmpPMat::::new(self, rows, cols, limbs) } - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(size, rows, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VmpPMat { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - rows: rows, - backend: self.backend(), - } + fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize } } - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize } + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } } - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { - debug_assert_eq!(a.len(), b.n * b.rows * b.cols); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] { + assert_eq!(a.len(), b.n() * b.poly_count()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -409,40 +379,17 @@ impl VmpPMatOps for Module { b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), b.rows() as u64, - b.cols() as u64, + (b.limbs()*b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) { - let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect(); + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] - { - debug_assert_eq!(a.len(), b.rows); - a.iter().for_each(|ai| { - debug_assert_eq!(ai.len(), b.n * b.cols); - }); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_dblptr( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - ptrs.as_ptr(), - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.cols() * self.n()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); + { + assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -452,16 +399,17 @@ impl VmpPMatOps for Module { a.as_ptr(), row_i as u64, b.rows() as u64, - b.cols() as u64, + (b.limbs()*b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); + assert_eq!(a.limbs(), b.limbs()); assert_eq!(a.cols(), b.cols()); } unsafe { @@ -471,16 +419,16 @@ impl VmpPMatOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.cols() as u64, + (a.limbs()*a.cols()) as u64, ); } } - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { + fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); + assert_eq!(a.limbs(), b.limbs()); } unsafe { vmp::vmp_prepare_row_dft( @@ -489,16 +437,16 @@ impl VmpPMatOps for Module { a.ptr as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, ); } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); + assert_eq!(a.limbs(), b.limbs()); } unsafe { vmp::vmp_extract_row_dft( @@ -507,48 +455,47 @@ impl VmpPMatOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.cols() as u64, + a.limbs() as u64, ); } } - fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize { + fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.ptr, - res_cols as u64, - a_cols as u64, + res_size as u64, + a_size as u64, gct_rows as u64, - gct_cols as u64, + gct_size as u64, ) as usize } } - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(a.size()*a.size(), b.size()); } unsafe { vmp::vmp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.as_ptr(), - a.cols() as u64, - (a.n()*a.size()) as u64, + a.limbs() as u64, + (a.n() * a.cols()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -557,32 +504,32 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.as_ptr(), - a.cols() as u64, - (a.n()*a.size()) as u64, + a.limbs() as u64, + (a.n() * a.limbs()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize { + fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.ptr, - res_cols as u64, - a_cols as u64, + res_size as u64, + a_size as u64, gct_rows as u64, - gct_cols as u64, + gct_size as u64, ) as usize } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -591,19 +538,19 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.limbs() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -612,19 +559,19 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_to_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.limbs() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -633,12 +580,12 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, - a.cols() as u64, + a.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -648,44 +595,45 @@ impl VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned, + FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + alloc_aligned, }; use sampling::source::Source; #[test] fn vmp_prepare_row_dft() { - let module: Module = Module::new(32, crate::BACKEND::FFT64); + let module: Module = Module::::new(32); let vpmat_rows: usize = 4; - let vpmat_cols: usize = 5; + let vpmat_size: usize = 5; let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); + let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); + let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); - let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols)); + let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); for row_i in 0..vpmat_rows { let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); module.vec_znx_dft(&mut a_dft, &a); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); - assert_eq!(vmpmat_0.raw::(), vmpmat_1.raw::()); + assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); - assert_eq!(a_dft.raw::(&module), b_dft.raw::(&module)); + assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); - assert_eq!(a_big.raw::(&module), b_big.raw::(&module)); + assert_eq!(a_big.raw(), b_big.raw()); } module.free(); diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 9d1fe1a..73addb5 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,6 +1,6 @@ use crate::elem::{Elem, ElemCommon}; use crate::parameters::Parameters; -use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat}; +use base2k::{Infos, Layout, Module, VecZnx, VmpPMat}; pub struct Ciphertext(pub Elem); @@ -38,7 +38,7 @@ where self.elem().size() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.elem().layout() } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index e7e61c4..656cc3a 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,4 @@ -use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; +use base2k::{Infos, Layout, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; pub struct Elem { pub value: Vec, @@ -71,7 +71,7 @@ pub trait ElemCommon { fn elem(&self) -> &Elem; fn elem_mut(&mut self) -> &mut Elem; fn size(&self) -> usize; - fn layout(&self) -> LAYOUT; + fn layout(&self) -> Layout; fn rows(&self) -> usize; fn cols(&self) -> usize; fn log_base2k(&self) -> usize; @@ -102,7 +102,7 @@ impl ElemCommon for Elem { self.value.len() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.value[0].layout() } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 86f7e32..258756b 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,7 +1,7 @@ use crate::ciphertext::Ciphertext; use crate::elem::{Elem, ElemCommon, ElemVecZnx}; use crate::parameters::Parameters; -use base2k::{LAYOUT, Module, VecZnx}; +use base2k::{Layout, Module, VecZnx}; pub struct Plaintext(pub Elem); @@ -79,7 +79,7 @@ impl ElemCommon for Plaintext { self.elem().size() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.elem().layout() }