diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 1d977ff..f66a4d1 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -18,7 +18,7 @@ fn main() { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(cols); + let mut res: VecZnx = module.new_vec_znx(1, cols); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); @@ -31,11 +31,11 @@ fn main() { module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(cols); + let mut a: VecZnx = module.new_vec_znx(1, cols); module.fill_uniform(log_base2k, &mut a, cols, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.cols()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols()); // Applies buf_dft <- s * a module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); @@ -46,21 +46,21 @@ fn main() { // buf_big <- IDFT(buf_dft) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - let mut m: VecZnx = module.new_vec_znx(msg_cols); + let mut m: VecZnx = module.new_vec_znx(1, msg_cols); let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); // m - m.encode_vec_i64(log_base2k, log_scale, &want, 4); + m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); m.normalize(log_base2k, &mut carry); // buf_big <- m - buf_big module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(cols); + let mut b: VecZnx = module.new_vec_znx(1, cols); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); module.add_normal( log_base2k, @@ -85,7 +85,7 @@ fn main() { // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(log_base2k, res.cols() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have); let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index be40e25..a69c857 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,6 +1,6 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, - VmpPMatOps, alloc_aligned, + BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + alloc_aligned, }; fn main() { @@ -23,40 +23,34 @@ fn main() { let mut a_values: Vec = vec![i64::default(); n]; a_values[1] = (1 << log_base2k) + 1; - let mut a: VecZnx = module.new_vec_znx(cols); - a.encode_vec_i64(log_base2k, log_k, &a_values, 32); + let mut a: VecZnx = module.new_vec_znx(1, rows); + a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32); a.normalize(log_base2k, &mut buf); - a.print(a.cols(), n); + a.print(0, a.cols(), n); println!(); - let mut vecznx: Vec = Vec::new(); - (0..rows).for_each(|_| { - vecznx.push(module.new_vec_znx(cols)); + let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols); + + (0..a.cols()).for_each(|row_i| { + let mut tmp: VecZnx = module.new_vec_znx(1, cols); + tmp.at_mut(row_i)[1] = 1 as i64; + module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); }); - (0..rows).for_each(|i| { - vecznx[i].raw_mut()[i * n + 1] = 1 as i64; - 
});
-
-    let slices: Vec<&[i64]> = vecznx.dblptr();
-
-    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
-
-    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
+    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
     module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
 
     let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
 
     module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
 
-    let mut res: VecZnx = module.new_vec_znx(cols);
+    let mut res: VecZnx = module.new_vec_znx(1, rows);
     module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);
 
     let mut values_res: Vec<i64> = vec![i64::default(); n];
-    res.decode_vec_i64(log_base2k, log_k, &mut values_res);
+    res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
 
-    res.print(res.cols(), n);
+    res.print(0, res.cols(), n);
 
     module.free();
diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs
index 4615838..c8c08e9 100644
--- a/base2k/src/encoding.rs
+++ b/base2k/src/encoding.rs
@@ -9,94 +9,104 @@ pub trait Encoding {
     ///
     /// # Arguments
     ///
-    /// * `log_base2k`: base two logarithm decomposition of the receiver.
-    /// * `log_k`: base two logarithm of the scaling of the data.
+    /// * `poly_idx`: the index of the polynomial in which to encode the data.
+    /// * `log_base2k`: base two negative logarithm decomposition of the receiver.
+    /// * `log_k`: base two negative logarithm of the scaling of the data.
     /// * `data`: data to encode on the receiver.
     /// * `log_max`: base two logarithm of the infinity norm of the input data.
-    fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
+    fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
 
     /// decode a vector of i64 from the receiver.
     ///
     /// # Arguments
     ///
-    /// * `log_base2k`: base two logarithm decomposition of the receiver.
+    /// * `poly_idx`: the index of the polynomial from which to decode the data.
+    /// * `log_base2k`: base two negative logarithm decomposition of the receiver.
     /// * `log_k`: base two logarithm of the scaling of the data.
     /// * `data`: data to decode from the receiver.
-    fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]);
+    fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]);
 
     /// decode a vector of Float from the receiver.
     ///
     /// # Arguments
-    /// * `log_base2k`: base two logarithm decomposition of the receiver.
+    /// * `poly_idx`: the index of the polynomial from which to decode the data.
+    /// * `log_base2k`: base two negative logarithm decomposition of the receiver.
    /// * `data`: data to decode from the receiver.
-    fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]);
+    fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]);
 
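To make the reworked signatures concrete, here is a minimal round-trip sketch of the `poly_idx`-aware encode/decode API, modeled on the updated unit tests further down in this file (parameter values are illustrative):

```rust
use base2k::{Encoding, Infos, VecZnx};

fn roundtrip_sketch() {
    let (n, log_base2k, cols) = (8usize, 17usize, 5usize);
    let log_k: usize = cols * log_base2k - 5;
    // Two stacked polynomials; each one is encoded and decoded independently.
    let mut a: VecZnx = VecZnx::new(n, 2, cols);
    let data: Vec<i64> = (0..n as i64).collect();
    (0..a.size()).for_each(|poly_idx| {
        a.encode_vec_i64(poly_idx, log_base2k, log_k, &data, 10);
        let mut out: Vec<i64> = vec![0i64; n];
        a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut out);
        assert_eq!(out, data);
    });
}
```

`poly_idx` selects one polynomial of the stack; the remaining polynomials are left untouched by each call.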
    /// encodes a single i64 on the receiver at the given index.
     ///
     /// # Arguments
     ///
-    /// * `log_base2k`: base two logarithm decomposition of the receiver.
-    /// * `log_k`: base two logarithm of the scaling of the data.
+    /// * `poly_idx`: the index of the polynomial in which to encode the data.
+    /// * `log_base2k`: base two negative logarithm decomposition of the receiver.
+    /// * `log_k`: base two negative logarithm of the scaling of the data.
     /// * `i`: index of the coefficient on which to encode the data.
     /// * `data`: data to encode on the receiver.
     /// * `log_max`: base two logarithm of the infinity norm of the input data.
-    fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
+    fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
 
     /// decode a single of i64 from the receiver at the given index.
     ///
     /// # Arguments
     ///
-    /// * `log_base2k`: base two logarithm decomposition of the receiver.
-    /// * `log_k`: base two logarithm of the scaling of the data.
+    /// * `poly_idx`: the index of the polynomial from which to decode the data.
+    /// * `log_base2k`: base two negative logarithm decomposition of the receiver.
+    /// * `log_k`: base two negative logarithm of the scaling of the data.
     /// * `i`: index of the coefficient to decode.
     /// * `data`: data to decode from the receiver.
-    fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64;
+    fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
 }
 
 impl Encoding for VecZnx {
-    fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
-        encode_vec_i64(self, log_base2k, log_k, data, log_max)
+    fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
+        encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max)
     }
 
-    fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) {
-        decode_vec_i64(self, log_base2k, log_k, data)
+    fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
+        decode_vec_i64(self, poly_idx, log_base2k, log_k, data)
     }
 
-    fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) {
-        decode_vec_float(self, log_base2k, data)
+    fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) {
+        decode_vec_float(self, poly_idx, log_base2k, data)
     }
 
-    fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
-        encode_coeff_i64(self, log_base2k, log_k, i, value, log_max)
+    fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
+        encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max)
     }
 
-    fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 {
-        decode_coeff_i64(self, log_base2k, log_k, i)
+    fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
+        decode_coeff_i64(self, poly_idx, log_base2k, log_k, i)
     }
 }
 
-fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
+fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
     let cols: usize = (log_k + log_base2k - 1) / log_base2k;
 
-    debug_assert!(
-        cols <= a.cols(),
-        "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
-        cols,
-        a.cols()
-    );
+    #[cfg(debug_assertions)]
+    {
+        assert!(
+            cols <= a.cols(),
+            "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
+            cols,
+            a.cols()
+        );
+        assert!(poly_idx < a.size);
+        assert!(data.len() <= a.n())
+    }
 
-    let size: usize = min(data.len(), a.n());
+    let data_len: usize = data.len();
     let log_k_rem: usize = log_base2k - (log_k % log_base2k);
 
     (0..a.cols()).for_each(|i| unsafe {
-        znx_zero_i64_ref(size as u64, a.at_mut(i).as_mut_ptr());
+        znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i));
     });
 
     // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
     // values on the last
limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_mut(cols - 1)[..size].copy_from_slice(&data[..size]); + a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); @@ -105,7 +115,7 @@ fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64], .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } @@ -113,45 +123,53 @@ fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64], if log_k_rem != log_base2k { let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); (cols - steps..cols).rev().for_each(|i| { - a.at_mut(i)[..size] + a.at_poly_mut(poly_idx, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) } } -fn decode_vec_i64(a: &VecZnx, log_base2k: usize, log_k: usize, data: &mut [i64]) { +fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { let cols: usize = (log_k + log_base2k - 1) / log_base2k; - debug_assert!( - data.len() >= a.n(), - "invalid data: data.len()={} < a.n()={}", - data.len(), - a.n() - ); - data.copy_from_slice(a.at(0)); + #[cfg(debug_assertions)] + { + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + assert!(poly_idx < a.size()); + } + data.copy_from_slice(a.at_poly(poly_idx, 0)); let rem: usize = log_base2k - (log_k % log_base2k); (1..cols).for_each(|i| { if i == cols - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } }) } -fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) { +fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { let cols: usize = a.cols(); - debug_assert!( - data.len() >= a.n(), - "invalid data: data.len()={} < a.n()={}", - data.len(), - a.n() - ); + #[cfg(debug_assertions)] + { + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + assert!(poly_idx < a.size()); + } let prec: u32 = (log_base2k * cols) as u32; @@ -161,12 +179,12 @@ fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) { // y[i] = sum x[j][i] * 2^{-log_base2k*j} (0..cols).for_each(|i| { if i == 0 { - izip!(a.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -174,23 +192,29 @@ fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) { }); } -fn encode_coeff_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { 
-    debug_assert!(i < a.n());
+fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
     let cols: usize = (log_k + log_base2k - 1) / log_base2k;
-    debug_assert!(
-        cols <= a.cols(),
-        "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
-        cols,
-        a.cols()
-    );
+
+    #[cfg(debug_assertions)]
+    {
+        assert!(i < a.n());
+        assert!(
+            cols <= a.cols(),
+            "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
+            cols,
+            a.cols()
+        );
+        assert!(poly_idx < a.size());
+    }
+
     let log_k_rem: usize = log_base2k - (log_k % log_base2k);
 
-    (0..a.cols()).for_each(|j| a.at_mut(j)[i] = 0);
+    (0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0);
 
     // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
     // values on the last limb.
     // Else we decompose values base2k.
     if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
-        a.at_mut(cols - 1)[i] = value;
+        a.at_poly_mut(poly_idx, cols - 1)[i] = value;
     } else {
         let mask: i64 = (1 << log_base2k) - 1;
         let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
@@ -198,7 +222,7 @@ fn encode_coeff_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, v
             .rev()
             .enumerate()
             .for_each(|(j, j_rev)| {
-                a.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask;
+                a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask;
             })
     }
 
@@ -206,19 +230,23 @@ fn encode_coeff_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, v
     if log_k_rem != log_base2k {
         let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
         (cols - steps..cols).rev().for_each(|j| {
-            a.at_mut(j)[i] <<= log_k_rem;
+            a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem;
         })
     }
 }
 
-fn decode_coeff_i64(a: &VecZnx, log_base2k: usize, log_k: usize, i: usize) -> i64 {
+fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
+    #[cfg(debug_assertions)]
+    {
+        assert!(i < a.n());
+        assert!(poly_idx < a.size())
+    }
+
     let cols: usize = (log_k + log_base2k - 1) / log_base2k;
-    debug_assert!(i < a.n());
-    let data: &[i64] = a.raw();
-    let mut res: i64 = data[i];
+    let mut res: i64 = a.at_poly(poly_idx, 0)[i];
     let rem: usize = log_base2k - (log_k % log_base2k);
-    (1..cols).for_each(|i| {
-        let x = data[i * a.n()];
-        if i == cols - 1 && rem != log_base2k {
+    (1..cols).for_each(|j| {
+        let x = a.at_poly(poly_idx, j)[i];
+        if j == cols - 1 && rem != log_base2k {
             let k_rem: usize = log_base2k - rem;
             res = (res << k_rem) + (x >> rem);
@@ -231,7 +259,7 @@ fn decode_coeff_i64(a: &VecZnx, log_base2k: usize, log_k: usize, i: usize) -> i64
 
 #[cfg(test)]
 mod tests {
-    use crate::{Encoding, VecZnx};
+    use crate::{Encoding, Infos, VecZnx};
     use itertools::izip;
     use sampling::source::Source;
 
         let log_base2k: usize = 17;
         let cols: usize = 5;
         let log_k: usize = cols * log_base2k - 5;
-        let mut a: VecZnx = VecZnx::new(n, cols);
-        let mut have: Vec<i64> = vec![i64::default(); n];
-        have.iter_mut()
-            .enumerate()
-            .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2);
-        a.encode_vec_i64(log_base2k, log_k, &have, 10);
-        let mut want = vec![i64::default(); n];
-        a.decode_vec_i64(log_base2k, log_k, &mut want);
-        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
+        let mut a: VecZnx = VecZnx::new(n, 2, cols);
+        let mut source: Source = Source::new([0u8; 32]);
+        let raw: &mut [i64] = a.raw_mut();
+        raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
+        (0..a.size()).for_each(|poly_idx| {
+            let mut have: Vec<i64> = vec![i64::default(); n];
+            have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56); + a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10); + let mut want: Vec = vec![i64::default(); n]; + a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }); } #[test] @@ -258,19 +292,17 @@ mod tests { let log_base2k: usize = 17; let cols: usize = 5; let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, cols); - let mut have: Vec = vec![i64::default(); n]; - let mut source = Source::new([1; 32]); - have.iter_mut().for_each(|x| { - *x = source - .next_u64n(u64::MAX, u64::MAX) - .wrapping_sub(u64::MAX / 2 + 1) as i64; - }); - a.encode_vec_i64(log_base2k, log_k, &have, 63); - //(0..a.cols()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - let mut want = vec![i64::default(); n]; - //(0..a.cols()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - a.decode_vec_i64(log_base2k, log_k, &mut want); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + let mut a: VecZnx = VecZnx::new(n, 2, cols); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.size()).for_each(|poly_idx| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut().for_each(|x| *x = source.next_i64()); + a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64); + let mut want = vec![i64::default(); n]; + a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }) } } diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index 6898c94..08472d9 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,3 +1,5 @@ +use crate::LAYOUT; + pub trait Infos { /// Returns the ring degree of the receiver. fn n(&self) -> usize; @@ -5,6 +7,12 @@ pub trait Infos { /// Returns the base two logarithm of the ring dimension of the receiver. fn log_n(&self) -> usize; + /// Returns the number of stacked polynomials. + fn size(&self) -> usize; + + /// Returns the memory layout of the stacked polynomials. + fn layout(&self) -> LAYOUT; + /// Returns the number of columns of the receiver. /// This method is equivalent to [Infos::cols]. 
fn cols(&self) -> usize;
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index ec0d2b7..7e97b00 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -27,6 +27,13 @@ pub use vmp::*;
 pub const GALOISGENERATOR: u64 = 5;
 pub const DEFAULTALIGN: usize = 64;
 
+#[derive(Copy, Clone)]
+#[repr(u8)]
+pub enum LAYOUT {
+    ROW,
+    COL,
+}
+
 pub fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
     (ptr as usize) % align == 0
 }
diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs
index 776ae75..f72ebaa 100644
--- a/base2k/src/stats.rs
+++ b/base2k/src/stats.rs
@@ -4,10 +4,10 @@ use rug::float::Round;
 use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
 
 impl VecZnx {
-    pub fn std(&self, log_base2k: usize) -> f64 {
+    pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 {
         let prec: u32 = (self.cols() * log_base2k) as u32;
         let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
-        self.decode_vec_float(log_base2k, &mut data);
+        self.decode_vec_float(poly_idx, log_base2k, &mut data);
         // std = sqrt(sum((xi - avg)^2) / n)
         let mut avg: Float = Float::with_val(prec, 0);
         data.iter().for_each(|x| {
diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs
index 9b2c64b..0e85a31 100644
--- a/base2k/src/svp.rs
+++ b/base2k/src/svp.rs
@@ -1,6 +1,6 @@
 use crate::ffi::svp::{self, svp_ppol_t};
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
-use crate::{BACKEND, Module, VecZnx, VecZnxDft, assert_alignement};
+use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement};
 use crate::{Infos, alloc_aligned, cast_mut};
 
 use rand::seq::SliceRandom;
@@ -117,7 +117,9 @@ impl Scalar {
     pub fn as_vec_znx(&self) -> VecZnx {
         VecZnx {
             n: self.n,
+            size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR
             cols: 1,
+            layout: LAYOUT::COL,
             data: Vec::new(),
             ptr: self.ptr,
         }
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 01659af..7445b5b 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -1,3 +1,4 @@
+use crate::LAYOUT;
 use crate::cast_mut;
 use crate::ffi::vec_znx;
 use crate::ffi::znx;
@@ -6,14 +7,27 @@ use crate::{alloc_aligned, assert_alignement};
 use itertools::izip;
 use std::cmp::min;
 
-/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients.
+/// [VecZnx] represents a collection of contiguously stacked vectors of small norm polynomials of
+/// Zn\[X\] with [i64] coefficients.
 /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
 /// in the memory.
+///
+/// # Example
+///
+/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, the memory
+/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
+/// are small polynomials of Zn\[X\].
 #[derive(Clone)]
 pub struct VecZnx {
     /// Polynomial degree.
     pub n: usize,
 
+    /// Stack size.
+    pub size: usize,
+
+    /// Stacking layout.
+    pub layout: LAYOUT,
+
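A worked sketch of the index arithmetic implied by the layout documented above; `poly_offset` is an illustrative helper (not part of the API) that mirrors the offset computation in `at_poly_ptr` below:

```rust
use base2k::{Infos, VecZnx};

// For polynomial i of a stack of `size` polynomials, column j starts at
// word offset n * (size * j + i), matching at_poly_ptr below.
fn poly_offset(n: usize, size: usize, i: usize, j: usize) -> usize {
    n * (size * j + i)
}

fn layout_sketch() {
    let a: VecZnx = VecZnx::new(8, 3, 4);
    assert_eq!((a.n(), a.size(), a.cols()), (8, 3, 4));
    // Column 1 of poly b (i = 1) starts right after [a0, b0, c0, a1]:
    assert_eq!(poly_offset(a.n(), a.size(), 1, 1), 8 * 4);
}
```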
    /// Number of columns.
     pub cols: usize,
@@ -24,23 +38,8 @@ pub struct VecZnx {
     pub ptr: *mut i64,
 }
 
-pub trait VecZnxVec {
-    fn dblptr(&self) -> Vec<&[i64]>;
-    fn dblptr_mut(&mut self) -> Vec<&mut [i64]>;
-}
-
-impl VecZnxVec for Vec<VecZnx> {
-    fn dblptr(&self) -> Vec<&[i64]> {
-        self.iter().map(|v| v.raw()).collect()
-    }
-
-    fn dblptr_mut(&mut self) -> Vec<&mut [i64]> {
-        self.iter_mut().map(|v| v.raw_mut()).collect()
-    }
-}
-
-pub fn bytes_of_vec_znx(n: usize, cols: usize) -> usize {
-    n * cols * 8
+pub fn bytes_of_vec_znx(n: usize, size: usize, cols: usize) -> usize {
+    n * size * cols * 8
 }
 
 impl VecZnx {
@@ -49,11 +48,12 @@ impl VecZnx {
     /// The struct will take ownership of buf[..[VecZnx::bytes_of]]
     ///
     /// User must ensure that data is properly alligned and that
-    /// the size of data is at least equal to [VecZnx::bytes_of].
-    pub fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
+    /// the size of data is equal to [VecZnx::bytes_of].
+    pub fn from_bytes(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(bytes.len(), Self::bytes_of(n, cols));
+            assert!(size > 0);
+            assert_eq!(bytes.len(), Self::bytes_of(n, size, cols));
             assert_alignement(bytes.as_ptr());
         }
         unsafe {
@@ -61,75 +61,135 @@ impl VecZnx {
             let ptr: *mut i64 = bytes_i64.as_mut_ptr();
             VecZnx {
                 n: n,
+                size: size,
                 cols: cols,
-                data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()),
+                layout: LAYOUT::COL,
+                data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
                 ptr: ptr,
             }
         }
     }
 
-    pub fn from_bytes_borrow(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
+    pub fn from_bytes_borrow(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
         #[cfg(debug_assertions)]
         {
-            assert!(bytes.len() >= Self::bytes_of(n, cols));
+            assert!(size > 0);
+            assert!(bytes.len() >= Self::bytes_of(n, size, cols));
             assert_alignement(bytes.as_ptr());
         }
         VecZnx {
             n: n,
+            size: size,
             cols: cols,
+            layout: LAYOUT::COL,
             data: Vec::new(),
             ptr: bytes.as_mut_ptr() as *mut i64,
         }
     }
 
-    pub fn bytes_of(n: usize, cols: usize) -> usize {
-        bytes_of_vec_znx(n, cols)
+    pub fn bytes_of(n: usize, size: usize, cols: usize) -> usize {
+        bytes_of_vec_znx(n, size, cols)
     }
 
     pub fn copy_from(&mut self, a: &VecZnx) {
         copy_vec_znx_from(self, a);
     }
 
-    pub fn raw(&self) -> &[i64] {
-        unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.cols) }
-    }
-
     pub fn borrowing(&self) -> bool {
         self.data.len() == 0
     }
 
-    pub fn raw_mut(&mut self) -> &mut [i64] {
-        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.cols) }
+    /// Returns a non-mutable reference to the backend slice of the receiver.
+    /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()].
+    pub fn raw(&self) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.cols) }
     }
 
+    /// Returns a mutable reference to the backend slice of the receiver.
+    /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()].
+    pub fn raw_mut(&mut self) -> &mut [i64] {
+        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.cols) }
+    }
+
+    /// Returns a non-mutable pointer to the backend slice of the receiver.
     pub fn as_ptr(&self) -> *const i64 {
         self.ptr
     }
 
+    /// Returns a mutable pointer to the backend slice of the receiver.
     pub fn as_mut_ptr(&mut self) -> *mut i64 {
         self.ptr
     }
 
-    pub fn at(&self, i: usize) -> &[i64] {
-        let n: usize = self.n();
-        &self.raw()[n * i..n * (i + 1)]
-    }
-
-    pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
-        let n: usize = self.n();
-        &mut self.raw_mut()[n * i..n * (i + 1)]
-    }
-
+    /// Returns a non-mutable pointer starting at the i-th column.
     pub fn at_ptr(&self, i: usize) -> *const i64 {
-        self.ptr.wrapping_add(i * self.n)
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.cols);
+        }
+        let offset: usize = self.n * self.size * i;
+        self.ptr.wrapping_add(offset)
     }
 
-    pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
-        self.ptr.wrapping_add(i * self.n)
+    /// Returns a non-mutable reference to the i-th column.
+    /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients.
+    pub fn at(&self, i: usize) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.at_ptr(i), self.n * self.size) }
+    }
+
+    /// Returns a non-mutable pointer starting at the j-th column of the i-th polynomial.
+    pub fn at_poly_ptr(&self, i: usize, j: usize) -> *const i64 {
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.size);
+            assert!(j < self.cols);
+        }
+        let offset: usize = self.n * (self.size * j + i);
+        self.ptr.wrapping_add(offset)
+    }
+
+    /// Returns a non-mutable reference to the j-th column of the i-th polynomial.
+    /// The slice contains one small polynomial of [VecZnx::n()] coefficients.
+    pub fn at_poly(&self, i: usize, j: usize) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.at_poly_ptr(i, j), self.n) }
+    }
+
+    /// Returns a mutable pointer starting at the i-th column.
+    pub fn at_mut_ptr(&self, i: usize) -> *mut i64 {
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.cols);
+        }
+        let offset: usize = self.n * self.size * i;
+        self.ptr.wrapping_add(offset)
+    }
+
+    /// Returns a mutable reference to the i-th column.
+    /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients.
+    pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
+        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n * self.size) }
+    }
+
+    /// Returns a mutable pointer starting at the j-th column of the i-th polynomial.
+    pub fn at_poly_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 {
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.size);
+            assert!(j < self.cols);
+        }
+
+        let offset: usize = self.n * (self.size * j + i);
+        self.ptr.wrapping_add(offset)
+    }
+
+    /// Returns a mutable reference to the j-th column of the i-th polynomial.
+    /// The slice contains one small polynomial of [VecZnx::n()] coefficients.
+    pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] {
+        let ptr: *mut i64 = self.at_poly_mut_ptr(i, j);
+        unsafe { std::slice::from_raw_parts_mut(ptr, self.n) }
     }
 
     pub fn zero(&mut self) {
-        unsafe { znx::znx_zero_i64_ref((self.n * self.cols) as u64, self.ptr) }
+        unsafe { znx::znx_zero_i64_ref((self.n * self.cols * self.size) as u64, self.ptr) }
     }
 
     pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
@@ -144,8 +204,8 @@ impl VecZnx {
         switch_degree(a, self)
     }
 
-    pub fn print(&self, cols: usize, n: usize) {
-        (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n]))
+    pub fn print(&self, poly: usize, cols: usize, n: usize) {
+        (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n]))
     }
 }
 
@@ -160,6 +220,14 @@ impl Infos for VecZnx {
         self.n
     }
 
+    fn size(&self) -> usize {
+        self.size
+    }
+
+    fn layout(&self) -> LAYOUT {
+        self.layout
+    }
+
     /// Returns the number of cols of the [VecZnx].
     fn cols(&self) -> usize {
         self.cols
@@ -182,11 +250,20 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
 
 impl VecZnx {
     /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\].
-    pub fn new(n: usize, cols: usize) -> Self {
-        let mut data: Vec<i64> = alloc_aligned::<i64>(n * cols);
+    pub fn new(n: usize, size: usize, cols: usize) -> Self {
+        #[cfg(debug_assertions)]
+        {
+            assert!(n > 0);
+            assert!(n & (n - 1) == 0);
+            assert!(size > 0);
+            assert!(cols > 0);
+        }
+        let mut data: Vec<i64> = alloc_aligned::<i64>(n * size * cols);
         let ptr: *mut i64 = data.as_mut_ptr();
         Self {
             n: n,
+            size: size,
+            layout: LAYOUT::COL,
             cols: cols,
             data: data,
             ptr: ptr,
@@ -206,7 +283,7 @@ impl VecZnx {
 
         if !self.borrowing() {
             self.data
-                .truncate((self.cols() - k / log_base2k) * self.n());
+                .truncate((self.cols() - k / log_base2k) * self.n() * self.size());
         }
 
         self.cols -= k / log_base2k;
@@ -244,14 +321,20 @@ pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) {
     });
 }
 
+fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
+    n * size * std::mem::size_of::<i64>()
+}
+
 fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
     let n: usize = a.n();
+    let size: usize = a.size();
 
     debug_assert!(
-        tmp_bytes.len() >= n * 8,
-        "invalid tmp_bytes: tmp_bytes.len()={} < self.n()={}",
+        tmp_bytes.len() >= normalize_tmp_bytes(n, size),
+        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})",
         tmp_bytes.len(),
-        n
+        n,
+        size,
     );
     #[cfg(debug_assertions)]
     {
@@ -264,7 +347,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
-        znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
+        znx::znx_zero_i64_ref((n * size) as u64, carry_i64.as_mut_ptr());
         (0..a.cols()).rev().for_each(|i| {
             znx::znx_normalize(
-                n as u64,
+                (n * size) as u64,
                 log_base2k as u64,
                 a.at_mut_ptr(i),
                 carry_i64.as_mut_ptr(),
@@ -275,27 +358,32 @@
     }
 }
 
+pub fn rsh_tmp_bytes(n: usize, size: usize) -> usize {
+    n * size * std::mem::size_of::<i64>()
+}
+
 pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) {
     let n: usize = a.n();
-
-    debug_assert!(
-        tmp_bytes.len() >> 3 >= n,
-        "invalid carry: carry.len()/8={} < self.n()={}",
-        tmp_bytes.len() >> 3,
-        n
-    );
+    let size: usize = a.size();
 
     #[cfg(debug_assertions)]
     {
-        assert_alignement(tmp_bytes.as_ptr())
+        assert!(
+            tmp_bytes.len() >= rsh_tmp_bytes(n, size),
+            "invalid tmp_bytes: tmp_bytes.len()={} < rsh_tmp_bytes({}, {})",
+            tmp_bytes.len(),
+            n,
+            size,
+        );
+        assert_alignement(tmp_bytes.as_ptr());
     }
 
     let cols: usize = a.cols();
     let cols_steps: usize = k / log_base2k;
-    a.raw_mut().rotate_right(n * cols_steps);
+    a.raw_mut().rotate_right(n * size * cols_steps);
 
     unsafe {
-        znx::znx_zero_i64_ref((n * cols_steps) as u64, a.as_mut_ptr());
+        znx::znx_zero_i64_ref((n * size * cols_steps) as u64, a.as_mut_ptr());
     }
 
     let k_rem = k % log_base2k;
@@ -304,7 +392,7 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) {
         let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
 
         unsafe {
-            znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
+            znx::znx_zero_i64_ref((n * size) as u64, carry_i64.as_mut_ptr());
         }
 
         let log_base2k: usize = log_base2k;
@@ -330,13 +418,13 @@ pub trait VecZnxOps {
     /// # Arguments
     ///
     /// * `cols`: the number of cols.
-    fn new_vec_znx(&self, cols: usize) -> VecZnx;
+    fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx;
 
     /// Returns the minimum number of bytes necessary to allocate
     /// a new [VecZnx] through [VecZnx::from_bytes].
-    fn bytes_of_vec_znx(&self, cols: usize) -> usize;
+    fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize;
 
-    fn vec_znx_normalize_tmp_bytes(&self) -> usize;
+    fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize;
 
     /// c <- a + b.
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); @@ -389,162 +477,216 @@ pub trait VecZnxOps { } impl VecZnxOps for Module { - fn new_vec_znx(&self, cols: usize) -> VecZnx { - VecZnx::new(self.n(), cols) + fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx { + VecZnx::new(self.n(), size, cols) } - fn bytes_of_vec_znx(&self, cols: usize) -> usize { - self.n() * cols * 8 + fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize { + bytes_of_vec_znx(self.n(), size, cols) } - fn vec_znx_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } + fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * size } } // c <- a + b fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(c.n(), n); + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_add( self.ptr, c.as_mut_ptr(), c.cols() as u64, - c.n() as u64, + (n * c.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, b.as_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, ) } } // b <- a + b fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_add( self.ptr, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, b.as_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, ) } } // c <- a + b fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(c.n(), n); + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_sub( self.ptr, c.as_mut_ptr(), c.cols() as u64, - c.n() as u64, + (n * c.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, b.as_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, ) } } // b <- a - b fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, b.as_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, ) } } // b <- b - a fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, b.as_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ) } } fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_negate( self.ptr, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ) } } fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), 
n); + } unsafe { vec_znx::vec_znx_negate( self.ptr, a.as_mut_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ) } } fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_rotate( self.ptr, k, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ) } } fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + } unsafe { vec_znx::vec_znx_rotate( self.ptr, k, a.as_mut_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ) } } @@ -562,18 +704,22 @@ impl VecZnxOps for Module { /// /// The method will panic if the argument `a` is greater than `a.cols()`. fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - debug_assert_eq!(a.n(), self.n()); - debug_assert_eq!(b.n(), self.n()); + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + } unsafe { vec_znx::vec_znx_automorphism( self.ptr, k, b.as_mut_ptr(), b.cols() as u64, - b.n() as u64, + (n * b.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ); } } @@ -590,17 +736,21 @@ impl VecZnxOps for Module { /// /// The method will panic if the argument `cols` is greater than `self.cols()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { - debug_assert_eq!(a.n(), self.n()); + let n: usize = self.n(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), n); + } unsafe { vec_znx::vec_znx_automorphism( self.ptr, k, a.as_mut_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (n * a.size()) as u64, ); } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 0de9c8c..705a5ec 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,11 +1,13 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{BACKEND, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; pub struct VecZnxBig { pub data: Vec, pub ptr: *mut u8, pub n: usize, + pub size: usize, pub cols: usize, + pub layout: LAYOUT, pub backend: BACKEND, } @@ -13,10 +15,10 @@ impl VecZnxBig { /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. 
- pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols)); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -24,22 +26,26 @@ impl VecZnxBig { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: module.backend, } } } - pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols)); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); assert_alignement(bytes.as_ptr()); } Self { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: module.backend, } @@ -50,6 +56,8 @@ impl VecZnxBig { data: Vec::new(), ptr: self.ptr, n: self.n, + size: self.size, + layout: LAYOUT::COL, cols: self.cols, backend: self.backend, } @@ -81,6 +89,14 @@ impl Infos for VecZnxBig { self.n } + fn size(&self) -> usize { + self.size + } + + fn layout(&self) -> LAYOUT { + self.layout + } + /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols @@ -94,7 +110,7 @@ impl Infos for VecZnxBig { pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -107,7 +123,7 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -120,11 +136,11 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. 
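As a usage sketch for the size-aware allocation path in this trait, assuming an already-constructed `Module` (its constructor is not shown in this patch):

```rust
use base2k::{Infos, Module, VecZnxBig, VecZnxBigOps, alloc_aligned};

fn big_scratch_sketch(module: &Module) {
    let (size, cols) = (1usize, 5usize);
    // Size the caller-owned scratch buffer with the new (size, cols) signature.
    let mut tmp: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_big(size, cols));
    // The borrow variant leaves ownership of `tmp` with the caller.
    let big: VecZnxBig = module.new_vec_znx_big_from_bytes_borrow(size, cols, &mut tmp);
    assert_eq!((big.size(), big.cols()), (size, cols));
}
```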
- fn bytes_of_vec_znx_big(&self, cols: usize) -> usize; + fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize; /// b <- b - a fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); @@ -162,28 +178,30 @@ pub trait VecZnxBigOps { } impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols)); + fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig { + let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(size, cols)); let ptr: *mut u8 = data.as_mut_ptr(); VecZnxBig { data: data, ptr: ptr, n: self.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: self.backend(), } } - fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, bytes) + fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, size, cols, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, tmp_bytes) + fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes) } - fn bytes_of_vec_znx_big(&self, cols: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize } + fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size } } fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 7798298..8b31ea6 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,13 +1,15 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{BACKEND, Infos, Module, VecZnxBig, assert_alignement}; +use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement}; use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; pub struct VecZnxDft { pub data: Vec, pub ptr: *mut u8, pub n: usize, + pub size: usize, + pub layout: LAYOUT, pub cols: usize, pub backend: BACKEND, } @@ -16,10 +18,10 @@ impl VecZnxDft { /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. 
- pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols)); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); assert_alignement(bytes.as_ptr()) } unsafe { @@ -27,22 +29,26 @@ impl VecZnxDft { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: module.backend, } } } - pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols)); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); assert_alignement(bytes.as_ptr()); } VecZnxDft { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: module.backend, } @@ -56,6 +62,8 @@ impl VecZnxDft { data: Vec::new(), ptr: self.ptr, n: self.n, + layout: LAYOUT::COL, + size: self.size, cols: self.cols, backend: self.backend, } @@ -105,6 +113,14 @@ impl Infos for VecZnxDft { self.n } + fn size(&self) -> usize { + self.size + } + + fn layout(&self) -> LAYOUT { + self.layout + } + /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols @@ -118,7 +134,7 @@ impl Infos for VecZnxDft { pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -131,7 +147,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -144,7 +160,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -155,7 +171,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize; + fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. 
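A small sketch of how the size-aware `VecZnxDft` allocators compose with the forward DFT used elsewhere in this patch (`vec_znx_dft`'s exact signature is assumed from its call sites in `automorphism.rs`):

```rust
use base2k::{Infos, Module, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps};

fn dft_alloc_sketch(module: &Module) {
    let cols: usize = 2;
    // Single stacked polynomial (size = 1), matching the updated tests below.
    let a: VecZnx = module.new_vec_znx(1, cols);
    let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols());
    // Forward DFT; signature assumed from automorphism.rs in this patch.
    module.vec_znx_dft(&mut a_dft, &a);
    assert_eq!((a_dft.size(), a_dft.cols()), (1, cols));
}
```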
@@ -176,28 +192,30 @@ pub trait VecZnxDftOps { } impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_dft(cols)); + fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft { + let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_dft(size, cols)); let ptr: *mut u8 = data.as_mut_ptr(); VecZnxDft { data: data, ptr: ptr, n: self.n(), + size: size, + layout: LAYOUT::COL, cols: cols, backend: self.backend(), } } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes(self, size, cols, tmp_bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes) } - fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize } + fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size } } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { @@ -317,9 +335,9 @@ mod tests { let cols: usize = 2; let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(cols); + let mut a: VecZnx = module.new_vec_znx(1, cols); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); let mut source: Source = Source::new(new_seed()); module.fill_uniform(log_base2k, &mut a, cols, &mut source); diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 1ffdfc0..7d6c26f 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{BACKEND, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. @@ -23,8 +23,11 @@ pub struct VmpPMat { cols: usize, /// The ring degree of each [VecZnxDft]. n: usize, - - #[warn(dead_code)] + /// The number of stacked [VmpPMat], must be a square. + size: usize, + /// The memory layout of the stacked [VmpPMat]. + layout: LAYOUT, + /// The backend fft or ntt. backend: BACKEND, } @@ -38,6 +41,14 @@ impl Infos for VmpPMat { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } + fn size(&self) -> usize { + self.size + } + + fn layout(&self) -> LAYOUT { + self.layout + } + /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat] fn rows(&self) -> usize { self.rows @@ -120,12 +131,16 @@ impl VmpPMat { &self.raw::()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
} } + + fn backend(&self) -> BACKEND { + self.backend + } } /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize; + fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// @@ -133,7 +148,7 @@ pub trait VmpPMatOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `cols`: number of cols (number of cols of each [VecZnxDft]). - fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat; + fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// @@ -360,17 +375,19 @@ pub trait VmpPMatOps { } impl VmpPMatOps for Module { - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize } + fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size } } - fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(rows, cols)); + fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat { + let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(size, rows, cols)); let ptr: *mut u8 = data.as_mut_ptr(); VmpPMat { data: data, ptr: ptr, n: self.n(), + size: size, + layout: LAYOUT::COL, cols: cols, rows: rows, backend: self.backend(), @@ -424,10 +441,10 @@ impl VmpPMatOps for Module { } fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - debug_assert_eq!(a.len(), b.cols() * self.n()); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); #[cfg(debug_assertions)] { + assert_eq!(a.len(), b.cols() * self.n()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -642,13 +659,13 @@ mod tests { let vpmat_rows: usize = 4; let vpmat_cols: usize = 5; let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(vpmat_cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols); - let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols); - let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); + let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); + let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); + let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); + let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); + let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols)); diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 94df0b6..fdd2240 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,6 +1,4 @@ -use base2k::{ - BACKEND, Infos, 
Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
-};
+use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8};
 use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
 use rlwe::{
     ciphertext::{Ciphertext, new_gadget_ciphertext},
@@ -106,10 +104,10 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
         &mut tmp_bytes,
     );
 
-    let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
-    let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
+    let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
+    let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
 
-    let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
+    let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
     params
         .module()
         .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs
index cd9c7a1..b9d66cd 100644
--- a/rlwe/examples/encryption.rs
+++ b/rlwe/examples/encryption.rs
@@ -39,11 +39,11 @@ fn main() {
 
     let log_k: usize = params.log_q() - 20;
 
-    pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32);
+    pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32);
     pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
 
     println!("log_k: {}", log_k);
-    pt.0.value[0].print(pt.cols(), 16);
+    pt.0.value[0].print(0, pt.cols(), 16);
     println!();
 
     let mut ct: Ciphertext = params.new_ciphertext(params.log_q());
@@ -64,12 +64,12 @@ fn main() {
     );
 
     params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
-    pt.0.value[0].print(pt.cols(), 16);
+    pt.0.value[0].print(0, pt.cols(), 16);
 
     let mut have = vec![i64::default(); params.n()];
 
     println!("pt: {}", log_k);
-    pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have);
+    pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have);
 
     println!("want: {:?}", &want[..16]);
     println!("have: {:?}", &have[..16]);
diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs
index 46a4fc5..5e5b48a 100644
--- a/rlwe/src/automorphism.rs
+++ b/rlwe/src/automorphism.rs
@@ -11,7 +11,7 @@ use base2k::{
     VmpPMatOps, assert_alignement,
 };
 use sampling::source::Source;
-use std::{cmp::min, collections::HashMap};
+use std::collections::HashMap;
 
 /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
 pub struct AutomorphismKey {
@@ -152,7 +152,7 @@ pub fn automorphism(
 pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
     return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
+        + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
 }
 
 pub fn automorphism_inplace(
@@ -184,11 +184,11 @@ pub fn automorphism_big(
         assert_alignement(tmp_bytes.as_ptr());
     }
 
-    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
-    let
diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs
index 46a4fc5..5e5b48a 100644
--- a/rlwe/src/automorphism.rs
+++ b/rlwe/src/automorphism.rs
@@ -11,7 +11,7 @@ use base2k::{
     VmpPMatOps, assert_alignement,
 };
 use sampling::source::Source;
-use std::{cmp::min, collections::HashMap};
+use std::collections::HashMap;
 
 /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
 pub struct AutomorphismKey {
@@ -152,7 +152,7 @@ pub fn automorphism(
 
 pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
     return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
+        + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
 }
 
 pub fn automorphism_inplace(
@@ -184,11 +184,11 @@ pub fn automorphism_big(
         assert_alignement(tmp_bytes.as_ptr());
     }
 
-    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
+    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
+    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
 
     // a1_dft = DFT(a[1])
     module.vec_znx_dft(&mut a1_dft, a.at(1));
@@ -295,7 +295,7 @@ mod test {
         let mut pt: Plaintext = params.new_plaintext(log_q);
         let mut pt_auto: Plaintext = params.new_plaintext(log_q);
 
-        pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
+        pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
 
         module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
 
         encrypt_rlwe_sk(
@@ -334,7 +334,7 @@ mod test {
 
         // pt.at(0).print(pt.cols(), 16);
 
-        let noise_have: f64 = pt.at(0).std(log_base2k).log2();
+        let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
 
         let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
         let var_a_err: f64 = 1f64 / 12f64;
diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs
index 67a7f77..9d1fe1a 100644
--- a/rlwe/src/ciphertext.rs
+++ b/rlwe/src/ciphertext.rs
@@ -1,6 +1,6 @@
 use crate::elem::{Elem, ElemCommon};
 use crate::parameters::Parameters;
-use base2k::{Infos, Module, VecZnx, VmpPMat};
+use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat};
 
 pub struct Ciphertext(pub Elem);
 
@@ -38,6 +38,10 @@ where
         self.elem().size()
     }
 
+    fn layout(&self) -> LAYOUT {
+        self.elem().layout()
+    }
+
    fn rows(&self) -> usize {
        self.elem().rows()
    }
diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs
index 04d56bc..6eeea27 100644
--- a/rlwe/src/decryptor.rs
+++ b/rlwe/src/decryptor.rs
@@ -9,7 +9,6 @@ use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZn
 use std::cmp::min;
 
 pub struct Decryptor {
-    #[warn(dead_code)]
     sk: SvpPPol,
 }
 
@@ -21,8 +20,8 @@ impl Decryptor {
     }
 }
 
-pub fn decrypt_rlwe_tmp_byte(module: &Module, limbs: usize) -> usize {
-    module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes()
+pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize {
+    module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes()
 }
 
 impl Parameters {
@@ -48,9 +47,9 @@ pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, s
         decrypt_rlwe_tmp_byte(module, cols)
     );
 
-    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
-    let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
+    let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
     let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
 
     // res_dft <- DFT(ct[1]) * DFT(sk)
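
The decryptor follows the scratch-space convention used at every call site in this patch: compute a byte budget, then carve borrowed size-1 DFT views off the front and leave the remainder for normalization. A sketch of that pattern using only the calls shown above; the wrapper function name is illustrative:

```rust
use base2k::{Module, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned};

// Illustrative sketch: allocate the decrypt scratch budget and split it the
// way `decrypt_rlwe` does, with a size-1 borrowed DFT view up front.
fn decrypt_scratch_layout(module: &Module, cols: usize) {
    let budget: usize =
        module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes();
    let mut tmp_bytes: Vec<u8> = alloc_aligned::<u8>(budget);

    // First chunk backs the DFT view, the rest feeds vec_znx_big_normalize.
    let (dft_bytes, _normalize_bytes) =
        tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
    let _res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, dft_bytes);
}
```
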
diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs
index e0252a6..e7e61c4 100644
--- a/rlwe/src/elem.rs
+++ b/rlwe/src/elem.rs
@@ -1,4 +1,4 @@
-use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
+use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
 
 pub struct Elem {
     pub value: Vec,
@@ -25,11 +25,11 @@ impl ElemVecZnx for Elem {
         let n: usize = module.n();
         assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
         let mut value: Vec<VecZnx> = Vec::new();
-        let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
-        let elem_size = VecZnx::bytes_of(n, limbs);
+        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
+        let elem_size = VecZnx::bytes_of(n, 1, cols);
         let mut ptr: usize = 0;
         (0..size).for_each(|_| {
-            value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..]));
+            value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..]));
             ptr += elem_size
         });
         Self {
@@ -45,11 +45,11 @@
         let n: usize = module.n();
         assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
         let mut value: Vec<VecZnx> = Vec::new();
-        let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
-        let elem_size = VecZnx::bytes_of(n, limbs);
+        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
+        let elem_size = VecZnx::bytes_of(n, 1, cols);
         let mut ptr: usize = 0;
         (0..size).for_each(|_| {
-            value.push(VecZnx::from_bytes_borrow(n, limbs, &mut bytes[ptr..]));
+            value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..]));
             ptr += elem_size
         });
         Self {
@@ -71,6 +71,7 @@ pub trait ElemCommon {
     fn elem(&self) -> &Elem;
     fn elem_mut(&mut self) -> &mut Elem;
     fn size(&self) -> usize;
+    fn layout(&self) -> LAYOUT;
     fn rows(&self) -> usize;
     fn cols(&self) -> usize;
     fn log_base2k(&self) -> usize;
@@ -101,6 +102,10 @@ impl ElemCommon for Elem {
         self.value.len()
     }
 
+    fn layout(&self) -> LAYOUT {
+        self.value[0].layout()
+    }
+
     fn rows(&self) -> usize {
         self.value[0].rows()
     }
@@ -135,9 +140,9 @@ impl Elem {
     pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
         assert!(rows > 0);
-        let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
+        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
         let mut value: Vec<VecZnx> = Vec::new();
-        (0..rows).for_each(|_| value.push(module.new_vec_znx(limbs)));
+        (0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols)));
         Self {
             value,
             log_q,
@@ -152,7 +157,7 @@
         assert!(rows > 0);
         assert!(cols > 0);
         let mut value: Vec<VmpPMat> = Vec::new();
-        (0..size).for_each(|_| value.push(module.new_vmp_pmat(rows, cols)));
+        (0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols)));
         Self {
             value: value,
             log_q: 0,
diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs
index b919826..bdb383c 100644
--- a/rlwe/src/encryptor.rs
+++ b/rlwe/src/encryptor.rs
@@ -108,7 +108,7 @@ impl EncryptorSk {
 }
 
 pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
-    module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
+    module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
 }
 
 pub fn encrypt_rlwe_sk(
     module: &Module,
@@ -151,10 +151,10 @@ fn encrypt_rlwe_sk_core(
     // c1 <- Z_{2^prec}[X]/(X^{N}+1)
     module.fill_uniform(log_base2k, c1, cols, source_xa);
 
-    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
     // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
+    let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
 
     // Applies buf_dft <- DFT(s) * DFT(c1)
     module.svp_apply_dft(&mut buf_dft, sk, c1);
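
One correction above: the first `elem.rs` hunk originally set `elem_size = VecZnx::bytes_of(n, size, cols)` while striding `ptr += elem_size` once per element, which over-advances whenever `size > 1`; it now uses `VecZnx::bytes_of(n, 1, cols)`, matching the parallel `from_bytes_borrow` hunk. Every `*_tmp_bytes` helper in this patch also derives its column count from bit sizes with the same ceiling division, spelled out once below (illustrative helper, not crate code):

```rust
// Illustrative: ceil(log_q / log_base2k), the column count used by
// encrypt_rlwe_sk_tmp_bytes and the other helpers in this patch.
fn cols_for(log_q: usize, log_base2k: usize) -> usize {
    (log_q + log_base2k - 1) / log_base2k
}

// For example, log_q = 54 with log_base2k = 8 needs 7 columns,
// while an exact multiple such as log_q = 64 needs 8.
fn cols_for_examples() {
    assert_eq!(cols_for(54, 8), 7);
    assert_eq!(cols_for(64, 8), 8);
}
```
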
diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs
index 85b10e6..bbf9642 100644
--- a/rlwe/src/gadget_product.rs
+++ b/rlwe/src/gadget_product.rs
@@ -46,7 +46,7 @@ pub fn gadget_product_core(
 
 pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
     return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(min(c_cols, a_cols));
+        + 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols));
 }
 
 /// Evaluates the gadget product: c.at(i) = IDFT()
@@ -66,11 +66,11 @@ pub fn gadget_product_big(
 ) {
     let cols: usize = min(c.cols(), a.cols());
 
-    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
+    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
+    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
 
     // a1_dft = DFT(a[1])
     module.vec_znx_dft(&mut a1_dft, a.at(1));
@@ -99,11 +99,11 @@ pub fn gadget_product(
 ) {
     let cols: usize = min(c.cols(), a.cols());
 
-    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
+    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
 
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
+    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
+    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
 
     let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
 
     // a1_dft = DFT(a[1])
@@ -206,7 +206,7 @@ mod test {
         // Intermediate buffers
 
         // Input polynomial, uniformly distributed
-        let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
+        let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
         params
             .module()
             .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
@@ -215,9 +215,9 @@
         let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
 
         // Ideal output = a * s
-        let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
+        let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols());
         let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
-        let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
+        let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols());
 
         // a * sk0
         params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
@@ -232,12 +232,12 @@
 
        // Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
        (1..a.cols() + 1).for_each(|a_cols| {
-            let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols);
+            let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols);
             a_trunc.copy_from(&a);
 
             (1..gadget_ct.cols() + 1).for_each(|b_cols| {
-                let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
-                let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
+                let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
+                let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
                 let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
                 let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
@@ -296,7 +296,7 @@ mod test {
 
         // pt.at(0).print(pt.elem().cols(), 16);
 
-        let noise_have: f64 = pt.at(0).std(log_base2k).log2();
+        let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
 
         let var_a_err: f64;
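
The tests above compare measured noise against an analytic budget, and `std` now also takes a leading polynomial index. A sketch of the measurement side, assuming `std` is callable on the decrypted `VecZnx` exactly as in the tests (any trait import it may need is elided; the helper name is illustrative):

```rust
use base2k::VecZnx;

// Illustrative: log2 of the empirical standard deviation of polynomial 0,
// the quantity the tests compare against the analytic noise variance.
fn log2_noise(pt0: &VecZnx, log_base2k: usize) -> f64 {
    // Leading `0` selects the polynomial, as in `pt.at(0).std(0, log_base2k)`.
    pt0.std(0, log_base2k).log2()
}
```
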
diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs
index 78d48d5..4e0001a 100644
--- a/rlwe/src/key_switching.rs
+++ b/rlwe/src/key_switching.rs
@@ -1,6 +1,6 @@
 use crate::ciphertext::Ciphertext;
 use crate::elem::ElemCommon;
-use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
+use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
 use std::cmp::min;
 
 pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
@@ -8,8 +8,8 @@ pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize,
     let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
     let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
     return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
-        + module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
-        + module.bytes_of_vec_znx_dft(gct_cols);
+        + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+        + module.bytes_of_vec_znx_dft(1, gct_cols);
 }
 
 pub fn key_switch_rlwe(
@@ -54,11 +54,11 @@ fn key_switch_rlwe_core(
         assert_alignement(tmp_bytes.as_ptr());
     }
 
-    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
+    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
 
-    let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
-    let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
+    let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
+    let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
     let mut res_big = res_dft.as_vec_znx_big();
 
     module.vec_znx_dft(&mut a1_dft, a.at(1));
diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs
index d7725c3..86f7e32 100644
--- a/rlwe/src/plaintext.rs
+++ b/rlwe/src/plaintext.rs
@@ -1,7 +1,7 @@
 use crate::ciphertext::Ciphertext;
 use crate::elem::{Elem, ElemCommon, ElemVecZnx};
 use crate::parameters::Parameters;
-use base2k::{Module, VecZnx};
+use base2k::{LAYOUT, Module, VecZnx};
 
 pub struct Plaintext(pub Elem);
 
@@ -79,6 +79,10 @@ impl ElemCommon for Plaintext {
         self.elem().size()
     }
 
+    fn layout(&self) -> LAYOUT {
+        self.elem().layout()
+    }
+
     fn rows(&self) -> usize {
         self.0.rows()
     }
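
`Ciphertext`, `Plaintext` and `Elem` all gain the same one-line `layout()` delegation, so a wrapper reports the layout of its first backing `VecZnx`. A usage sketch, under two assumptions that the diff suggests but does not prove: that the new `layout()` accessor sits on base2k's `Infos` trait alongside `rows()`/`cols()`, and that freshly allocated matrices carry the `LAYOUT::COL` tag set in `new_vmp_pmat` above (helper name illustrative):

```rust
use base2k::{Infos, LAYOUT};

// Illustrative: branch on the storage layout of anything exposing Infos,
// e.g. a VmpPMat allocated by `new_vmp_pmat`, without reaching into fields.
fn is_col_layout<T: Infos>(x: &T) -> bool {
    matches!(x.layout(), LAYOUT::COL)
}
```
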
diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs
index 0c1bdba..dc42602 100644
--- a/rlwe/src/rgsw_product.rs
+++ b/rlwe/src/rgsw_product.rs
@@ -18,8 +18,8 @@ pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usiz
     let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
     let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
     return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
-        + module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
-        + 2 * module.bytes_of_vec_znx_dft(gct_cols);
+        + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+        + 2 * module.bytes_of_vec_znx_dft(1, gct_cols);
 }
 
 pub fn rgsw_product(
@@ -40,13 +40,13 @@
         assert_alignement(tmp_bytes.as_ptr());
     }
 
-    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
-    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
-    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
+    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
+    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
+    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
 
-    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
-    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
-    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
+    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
+    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
+    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
 
     let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
     let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
@@ -82,13 +82,13 @@ pub fn rgsw_product_inplace(
         assert_alignement(tmp_bytes.as_ptr());
     }
 
-    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
-    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
-    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
+    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
+    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
+    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
 
-    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
-    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
-    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
+    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
+    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
+    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
 
     let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
     let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
@@ -193,7 +193,7 @@ mod test {
         let mut pt: Plaintext = params.new_plaintext(log_q);
         let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
 
-        pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
+        pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
 
         module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0));
 
@@ -222,7 +222,7 @@ mod test {
 
         // pt.at(0).print(pt.cols(), 16);
 
-        let noise_have: f64 = pt.at(0).std(log_base2k).log2();
+        let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
 
         let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
         let var_a0_err: f64 = params.xe() * params.xe();
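
The rgsw test drives its plaintext through `vec_znx_rotate` (multiplication by X^k) before encrypting. A standalone sketch of that encode-rotate-decode step, using only calls visible in the hunks above; the helper name and the `i64` type of `k` are assumptions:

```rust
use base2k::{Encoding, Module, VecZnx, VecZnxOps};

// Illustrative sketch of the test's pre-processing: encode `data`, rotate
// the polynomial by X^k, and decode the rotated copy for comparison.
fn encode_rotate_decode(
    module: &Module,
    log_base2k: usize,
    log_k: usize,
    k: i64,
    data: &[i64],
    tmp_bytes: &mut [u8], // scratch sized as in the examples above
) -> Vec<i64> {
    let cols: usize = (log_k + log_base2k - 1) / log_base2k;
    let mut pt: VecZnx = module.new_vec_znx(1, cols);
    let mut pt_rotate: VecZnx = module.new_vec_znx(1, cols);

    pt.encode_vec_i64(0, log_base2k, log_k, data, 32);
    pt.normalize(log_base2k, tmp_bytes);
    module.vec_znx_rotate(k, &mut pt_rotate, &mut pt);

    let mut have: Vec<i64> = vec![0i64; data.len()];
    pt_rotate.decode_vec_i64(0, log_base2k, log_k, &mut have);
    have
}
```
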
diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs
index 70bb92d..9e7feb8 100644
--- a/rlwe/src/trace.rs
+++ b/rlwe/src/trace.rs
@@ -22,7 +22,7 @@ impl Parameters {
 
 pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
     return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
+        + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
 }
 
 pub fn trace_inplace(
@@ -59,11 +59,11 @@ pub fn trace_inplace(
 
     let cols: usize = std::cmp::min(b_cols, a.cols());
 
-    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
+    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
+    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
 
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
+    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
+    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
     let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
 
     let log_base2k: usize = a.log_base2k();
@@ -189,12 +189,12 @@ mod test {
         let mut ct: Ciphertext = params.new_ciphertext(log_q);
         let mut pt: Plaintext = params.new_plaintext(log_q);
 
-        pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
+        pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
         pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
 
-        pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);
+        pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
 
-        pt.at(0).print(pt.cols(), 16);
+        pt.at(0).print(0, pt.cols(), 16);
 
         encrypt_rlwe_sk(
@@ -227,9 +227,9 @@
             &mut tmp_bytes,
         );
 
-        pt.at(0).print(pt.cols(), 16);
+        pt.at(0).print(0, pt.cols(), 16);
 
-        pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);
+        pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
 
         println!("trace: {:?}", &data[..16]);
     }
diff --git a/sampling/src/source.rs b/sampling/src/source.rs
index 9f51df6..c356163 100644
--- a/sampling/src/source.rs
+++ b/sampling/src/source.rs
@@ -45,7 +45,7 @@ impl Source {
         min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
     }
 
-    pub fn next_i64(&mut self) -> i64{
+    pub fn next_i64(&mut self) -> i64 {
         self.next_u64() as i64
     }
 }
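
For completeness, the reformatted `next_i64` is a plain bit-cast of `next_u64`, so roughly half of its outputs are negative. A tiny usage sketch with an illustrative function name:

```rust
use sampling::source::Source;

// Illustrative: draw `n` signed samples from a seeded Source; `next_i64`
// reinterprets the u64 output, so values cover the full i64 range.
fn signed_samples(seed: [u8; 32], n: usize) -> Vec<i64> {
    let mut source: Source = Source::new(seed);
    (0..n).map(|_| source.next_i64()).collect()
}
```
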