From 0cca56755ba7c72b389047deed0dc491a36dd4fe Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 25 Apr 2025 11:04:17 +0200 Subject: [PATCH 01/87] wip --- base2k/src/ffi/module.rs | 2 - base2k/src/infos.rs | 49 +++++++---- base2k/src/lib.rs | 7 -- base2k/src/svp.rs | 2 +- base2k/src/vec_znx.rs | 171 ++++++++++++++++++-------------------- base2k/src/vec_znx_dft.rs | 4 - base2k/src/vmp.rs | 23 +++-- 7 files changed, 130 insertions(+), 128 deletions(-) diff --git a/base2k/src/ffi/module.rs b/base2k/src/ffi/module.rs index 755d613..e35d4c0 100644 --- a/base2k/src/ffi/module.rs +++ b/base2k/src/ffi/module.rs @@ -3,8 +3,6 @@ pub struct module_info_t { } pub type module_type_t = ::std::os::raw::c_uint; -pub const module_type_t_FFT64: module_type_t = 0; -pub const module_type_t_NTT120: module_type_t = 1; pub use self::module_type_t as MODULE_TYPE; pub type MODULE = module_info_t; diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index 08472d9..2445022 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,22 +1,43 @@ -use crate::LAYOUT; +#[derive(Copy, Clone)] +#[repr(C)] +pub struct LAYOUT{ + /// Ring degree. + n: usize, + /// Number of logical rows in the layout. + rows: usize, + /// Number of polynomials per row. + cols: usize, + /// Number of limbs per polynomial. + size: usize, + /// Whether limbs are interleaved across rows. + /// + /// For example, for (rows, cols, size) = (2, 2, 3): + /// + /// - `true`: layout is ((a0, b0, a1, b1), (c0, d0, c1, d1)) + /// - `false`: layout is ((a0, a1, b0, b1), (c0, c1, d0, d1)) + interleaved : bool, +} pub trait Infos { - /// Returns the ring degree of the receiver. - fn n(&self) -> usize; - /// Returns the base two logarithm of the ring dimension of the receiver. - fn log_n(&self) -> usize; - - /// Returns the number of stacked polynomials. - fn size(&self) -> usize; - - /// Returns the memory layout of the stacked polynomials. + /// Returns the full layout. 
fn layout(&self) -> LAYOUT; - /// Returns the number of columns of the receiver. - /// This method is equivalent to [Infos::cols]. + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize; + + /// Returns the number of rows. + fn rows(&self) -> usize; + + /// Returns the number of polynomials in each row. fn cols(&self) -> usize; - /// Returns the number of rows of the receiver. - fn rows(&self) -> usize; + /// Returns the number of limbs per polynomial. + fn size(&self) -> usize; + + /// Whether limbs are interleaved across rows. + fn interleaved(&self) -> bool; } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7e97b00..ec0d2b7 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -27,13 +27,6 @@ pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum LAYOUT { - ROW, - COL, -} - pub fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index 0e85a31..bc37f86 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -119,7 +119,7 @@ impl Scalar { n: self.n, size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR cols: 1, - layout: LAYOUT::COL, + layout: LAYOUT::COL(1, 1), data: Vec::new(), ptr: self.ptr, } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 7445b5b..71a315e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -22,15 +22,12 @@ pub struct VecZnx { /// Polynomial degree. pub n: usize, - /// Stack size + /// Number of limbs pub size: usize, - /// Stacking layout + /// Layout pub layout: LAYOUT, - /// Number of columns. - pub cols: usize, - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. 
pub data: Vec, @@ -38,8 +35,8 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, size: usize, cols: usize) -> usize { - n * size * cols * 8 +pub fn bytes_of_vec_znx(n: usize, layout: LAYOUT, size: usize) -> usize { + n * layout.size() * size * 8 } impl VecZnx { @@ -49,11 +46,11 @@ impl VecZnx { /// /// User must ensure that data is properly alligned and that /// the size of data is equal to [VecZnx::bytes_of]. - pub fn from_bytes(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, size, cols)); + assert_eq!(bytes.len(), Self::bytes_of(n, layout, size)); assert_alignement(bytes.as_ptr()); } unsafe { @@ -62,33 +59,31 @@ impl VecZnx { VecZnx { n: n, size: size, - cols: cols, - layout: LAYOUT::COL, + layout: layout, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - pub fn from_bytes_borrow(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(n, size, cols)); + assert!(bytes.len() >= Self::bytes_of(n, layout, size)); assert_alignement(bytes.as_ptr()); } VecZnx { n: n, size: size, - cols: cols, - layout: LAYOUT::COL, + layout: layout, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } } - pub fn bytes_of(n: usize, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(n, size, cols) + pub fn bytes_of(n: usize, layout: LAYOUT, size: usize) -> usize { + bytes_of_vec_znx(n, layout, size) } pub fn copy_from(&mut self, a: &VecZnx) { @@ -99,15 +94,15 @@ impl VecZnx { self.data.len() == 0 } - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. + /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. 
pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.cols) } + unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.size) } } /// Returns a reference to backend slice of the receiver. - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. + /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.cols) } + unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.size) } } /// Returns a non-mutable pointer to the backedn slice of the receiver. @@ -124,7 +119,7 @@ impl VecZnx { pub fn at_ptr(&self, i: usize) -> *const i64 { #[cfg(debug_assertions)] { - assert!(i < self.cols); + assert!(i < self.size); } let offset: usize = self.n * self.size * i; self.ptr.wrapping_add(offset) @@ -141,7 +136,7 @@ impl VecZnx { #[cfg(debug_assertions)] { assert!(i < self.size); - assert!(j < self.cols); + assert!(j < self.size); } let offset: usize = self.n * (self.size * j + i); self.ptr.wrapping_add(offset) @@ -157,7 +152,7 @@ impl VecZnx { pub fn at_mut_ptr(&self, i: usize) -> *mut i64 { #[cfg(debug_assertions)] { - assert!(i < self.cols); + assert!(i < self.size); } let offset: usize = self.n * self.size * i; self.ptr.wrapping_add(offset) @@ -174,7 +169,7 @@ impl VecZnx { #[cfg(debug_assertions)] { assert!(i < self.size); - assert!(j < self.cols); + assert!(j < self.size); } let offset: usize = self.n * (self.size * j + i); @@ -189,7 +184,7 @@ impl VecZnx { } pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.cols * self.size) as u64, self.ptr) } + unsafe { znx::znx_zero_i64_ref((self.n * self.size * self.size) as u64, self.ptr) } } pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { @@ -204,8 +199,8 @@ impl VecZnx { switch_degree(a, self) } - pub fn print(&self, poly: usize, cols: usize, n: usize) { - 
(0..cols).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) + pub fn print(&self, poly: usize, size: usize, n: usize) { + (0..size).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) } } @@ -228,9 +223,9 @@ impl Infos for VecZnx { self.layout } - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.cols + /// Returns the number of size of the [VecZnx]. + fn size(&self) -> usize { + self.size } /// Returns the number of rows of the [VecZnx]. @@ -249,22 +244,22 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { } impl VecZnx { - /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\]. - pub fn new(n: usize, size: usize, cols: usize) -> Self { + /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. + pub fn new(n: usize, size: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); assert!(size > 0); - assert!(cols > 0); + assert!(size <= u8::MAX as usize); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(n * size * cols); + let mut data: Vec = alloc_aligned::(n * size * size); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, + layout: LAYOUT::COL(1, size as u8), size: size, - layout: LAYOUT::COL, - cols: cols, data: data, ptr: ptr, } @@ -283,16 +278,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate((self.cols() - k / log_base2k) * self.n() * self.size()); + .truncate((self.size() - k / log_base2k) * self.n() * self.size()); } - self.cols -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_mut(self.cols() - 1) + self.at_mut(self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -310,9 +305,9 @@ pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { b.zero(); } - let cols = min(a.cols(), b.cols()); + let size = min(a.size(), b.size()); - (0..cols).for_each(|i| { + 
(0..size).for_each(|i| { izip!( a.at(i).iter().step_by(gap_in), b.at_mut(i).iter_mut().step_by(gap_out) @@ -345,7 +340,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.cols()).rev().for_each(|i| { + (0..a.size()).rev().for_each(|i| { znx::znx_normalize( (n * size) as u64, log_base2k as u64, @@ -378,12 +373,12 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { assert_alignement(tmp_bytes.as_ptr()); } - let cols: usize = a.cols(); - let cols_steps: usize = k / log_base2k; + let size: usize = a.size(); + let size_steps: usize = k / log_base2k; - a.raw_mut().rotate_right(n * size * cols_steps); + a.raw_mut().rotate_right(n * size * size_steps); unsafe { - znx::znx_zero_i64_ref((n * size * cols_steps) as u64, a.as_mut_ptr()); + znx::znx_zero_i64_ref((n * size * size_steps) as u64, a.as_mut_ptr()); } let k_rem = k % log_base2k; @@ -397,7 +392,7 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let log_base2k: usize = log_base2k; - (cols_steps..cols).for_each(|i| { + (size_steps..size).for_each(|i| { izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k; *ci = get_base_k_carry(*xi, k_rem); @@ -417,12 +412,12 @@ pub trait VecZnxOps { /// /// # Arguments /// - /// * `cols`: the number of cols. - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx; + /// * `size`: the number of size. + fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. 
- fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize; fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize; @@ -477,12 +472,12 @@ pub trait VecZnxOps { } impl VecZnxOps for Module { - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx { - VecZnx::new(self.n(), size, cols) + fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx { + VecZnx::new(self.n(), size, size) } - fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(self.n(), size, cols) + fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize { + bytes_of_vec_znx(self.n(), size, size) } fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { @@ -502,13 +497,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, c.as_mut_ptr(), - c.cols() as u64, + c.size() as u64, (n * c.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -526,13 +521,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -551,13 +546,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, c.as_mut_ptr(), - c.cols() as u64, + c.size() as u64, (n * c.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -575,13 +570,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -599,13 +594,13 @@ 
impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -622,10 +617,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -641,10 +636,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -662,10 +657,10 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -682,27 +677,27 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. /// /// # Arguments /// /// * `a`: input. /// * `b`: output. /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of a_cols on which to apply the mapping. + /// * `a_size`: the number of a_size on which to apply the mapping. /// /// # Panics /// - /// The method will panic if the argument `a` is greater than `a.cols()`. + /// The method will panic if the argument `a` is greater than `a.size()`. 
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -715,26 +710,26 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ); } } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. /// /// # Arguments /// /// * `a`: input and output. /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of cols on which to apply the mapping. + /// * `a_size`: the number of size on which to apply the mapping. /// /// # Panics /// - /// The method will panic if the argument `cols` is greater than `self.cols()`. + /// The method will panic if the argument `size` is greater than `self.size()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -746,10 +741,10 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ); } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 8b31ea6..b512fd8 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -113,10 +113,6 @@ impl Infos for VecZnxDft { self.n } - fn size(&self) -> usize { - self.size - } - fn layout(&self) -> LAYOUT { self.layout } diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 7d6c26f..b04232d 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -5,27 +5,25 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_ /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. 
-/// Each row of the [VmpPMat] can be seen as a [VecZnxDft]. +/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft]. /// -/// The backend array of [VmpPMat] is allocate in C, -/// and thus must be manually freed. -/// -/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat]. +/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. /// See the trait [VmpPMatOps] for additional information. pub struct VmpPMat { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. ptr: *mut u8, - /// The number of [VecZnxDft]. + /// The size of the decomposition basis (i.e. nb. [VecZnxDft]). rows: usize, - /// The number of cols in each [VecZnxDft]. + /// The size of each [VecZnxDft]. cols: usize, /// The ring degree of each [VecZnxDft]. n: usize, - /// The number of stacked [VmpPMat], must be a square. - size: usize, - /// The memory layout of the stacked [VmpPMat]. + /// 1nd dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension). + /// A value greater than one enables to compute a sum of [VecZnx] x [VmpPMat]. + /// 2st dim: the number of stacked [VecZnxDft] (col-dimension). + /// A value greater than one enables to compute multiple [VecZnx] x [VmpPMat] in parallel. layout: LAYOUT, /// The backend fft or ntt. 
backend: BACKEND, @@ -531,6 +529,7 @@ impl VmpPMatOps for Module { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); + assert_eq!(a.size()*a.size(), b.size()); } unsafe { vmp::vmp_apply_dft( @@ -539,7 +538,7 @@ impl VmpPMatOps for Module { c.cols() as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (a.n()*a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64, @@ -561,7 +560,7 @@ impl VmpPMatOps for Module { c.cols() as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (a.n()*a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64, From 90b34e171d7a89808642be78887bf31705947dfc Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 25 Apr 2025 11:08:55 +0200 Subject: [PATCH 02/87] fixed typo in doc --- base2k/src/infos.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index 2445022..ba799d7 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -9,12 +9,12 @@ pub struct LAYOUT{ cols: usize, /// Number of limbs per polynomial. size: usize, - /// Whether limbs are interleaved across rows. + /// Whether limbs are interleaved inside a row. 
/// /// For example, for (rows, cols, size) = (2, 2, 3): /// - /// - `true`: layout is ((a0, b0, a1, b1), (c0, d0, c1, d1)) - /// - `false`: layout is ((a0, a1, b0, b1), (c0, c1, d0, d1)) + /// - `true`: layout is ((a0, b0, a1, b1, a2, b2), (c0, d0, c1, d1, c2, d2)) + /// - `false`: layout is ((a0, a1, a2, b0, b1, b2), (c0, c1, c2, d0, d1, d2)) interleaved : bool, } From 2a96f890473d0d008691155aec6cb6f09fac05d5 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 25 Apr 2025 15:24:09 +0200 Subject: [PATCH 03/87] wip --- base2k/examples/rlwe_encrypt.rs | 38 ++- base2k/examples/vector_matrix_product.rs | 35 ++- base2k/src/encoding.rs | 151 ++++----- base2k/src/ffi/vec_znx_big.rs | 79 ++--- base2k/src/infos.rs | 30 +- base2k/src/lib.rs | 20 +- base2k/src/module.rs | 48 ++- base2k/src/sampling.rs | 76 +++-- base2k/src/svp.rs | 57 ++-- base2k/src/vec_znx.rs | 374 +++++++++++----------- base2k/src/vec_znx_big.rs | 226 +++++++------ base2k/src/vec_znx_dft.rs | 227 +++++++------- base2k/src/vmp.rs | 384 ++++++++++------------- rlwe/src/ciphertext.rs | 4 +- rlwe/src/elem.rs | 6 +- rlwe/src/plaintext.rs | 4 +- 16 files changed, 864 insertions(+), 895 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index f66a4d1..1da44e9 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned, }; use itertools::izip; @@ -8,44 +8,48 @@ use sampling::source::Source; fn main() { let n: usize = 16; let log_base2k: usize = 18; - let cols: usize = 3; + let limbs: usize = 3; let msg_cols: usize = 2; let log_scale: usize = msg_cols * log_base2k - 5; - let module: Module = Module::new(n, BACKEND::FFT64); + let module: 
Module = Module::::new(n); let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, cols); + let mut res: VecZnx = module.new_vec_znx(1, limbs); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: SvpPPol = module.new_svp_ppol(); + let mut s_ppol: SvpPPol = module.new_svp_ppol(); // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, cols); - module.fill_uniform(log_base2k, &mut a, cols, &mut source); + let mut a: VecZnx = module.new_vec_znx(1, limbs); + module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + + // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs()); // Applies buf_dft <- s * a module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); // Alias scratch space - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); + let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); // buf_big <- IDFT(buf_dft) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + println!("{:?}", buf_big.raw()); + let mut m: VecZnx = module.new_vec_znx(1, msg_cols); let mut want: Vec = vec![0; n]; @@ -59,13 +63,17 @@ fn main() { // buf_big <- m - buf_big module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); + println!("{:?}", buf_big.raw()); + // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, cols); + let mut b: VecZnx = module.new_vec_znx(1, limbs); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); + b.print(n); module.add_normal( log_base2k, &mut b, - log_base2k * cols, + 0, + log_base2k * limbs, &mut source, 3.2, 19.0, @@ -80,14 +88,18 @@ fn main() { // buf_big <- a * 
s + b module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + println!("raw: {:?}", &buf_big.raw()); + // res <- normalize(buf_big) module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); + + // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); - let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64; + let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index a69c857..4e8b97e 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,5 +1,5 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned, }; @@ -7,50 +7,51 @@ fn main() { let log_n: i32 = 5; let n: usize = 1 << log_n; - let module: Module = Module::new(n, BACKEND::FFT64); + let module: Module = Module::::new(n); let log_base2k: usize = 15; - let cols: usize = 5; - let log_k: usize = log_base2k * cols - 5; + let limbs_vec: usize = 5; + let log_k: usize = log_base2k * limbs_vec - 5; - let rows: usize = cols; - let cols: usize = cols + 1; + let rows_mat: usize = limbs_vec; + let limbs_mat: usize = limbs_vec + 1; // Maximum size of the byte scratch needed - let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols); + let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat) + | module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat); let mut buf: Vec = 
alloc_aligned(tmp_bytes); let mut a_values: Vec = vec![i64::default(); n]; a_values[1] = (1 << log_base2k) + 1; - let mut a: VecZnx = module.new_vec_znx(1, rows); + let mut a: VecZnx = module.new_vec_znx(1, limbs_vec); a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32); a.normalize(log_base2k, &mut buf); - a.print(0, a.cols(), n); + a.print(n); println!(); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols); + let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows_mat, 1, limbs_mat); - (0..a.cols()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(1, cols); - tmp.at_mut(row_i)[1] = 1 as i64; + (0..a.limbs()).for_each(|row_i| { + let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); + tmp.at_limb_mut(row_i)[1] = 1 as i64; module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); }); - let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); + let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); - let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); + let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); - let mut res: VecZnx = module.new_vec_znx(1, rows); + let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; res.decode_vec_i64(0, log_base2k, log_k, &mut values_res); - res.print(0, res.cols(), n); + res.print(n); module.free(); diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index c8c08e9..d4085cb 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -9,129 +9,130 @@ pub trait Encoding { /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. 
/// * `log_k`: base two negative logarithm of the scaling of the data. /// * `data`: data to encode on the receiver. /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); /// decode a vector of i64 from the receiver. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two logarithm of the scaling of the data. /// * `data`: data to decode from the receiver. - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); /// decode a vector of Float from the receiver. /// /// # Arguments - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `data`: data to decode from the receiver. - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]); + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); /// encodes a single i64 on the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient on which to encode the data. /// * `data`: data to encode on the receiver. 
/// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); /// decode a single of i64 from the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient to decode. /// * `data`: data to decode from the receiver. - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; } impl Encoding for VecZnx { - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max) + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - decode_vec_i64(self, poly_idx, log_base2k, log_k, data) + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + decode_vec_i64(self, col_i, log_base2k, log_k, data) } - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - decode_vec_float(self, poly_idx, log_base2k, data) + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]) { + decode_vec_float(self, col_i, log_base2k, data) } - fn encode_coeff_i64(&mut self, poly_idx: usize, 
log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max) + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) } - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { - decode_coeff_i64(self, poly_idx, log_base2k, log_k, i) + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { + decode_coeff_i64(self, col_i, log_base2k, log_k, i) } } -fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + limbs <= a.limbs(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", + limbs, + a.limbs() ); - assert!(poly_idx < a.size); + assert!(col_i < a.cols()); assert!(data.len() <= a.n()) } let data_len: usize = data.len(); let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|i| unsafe { - znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i)); + // Zeroes coefficients of the i-th column + (0..a.limbs()).for_each(|i| unsafe { + znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); }); // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. 
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs) .rev() .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } // Case where self.prec % self.k != 0. if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|i| { - a.at_poly_mut(poly_idx, i)[..data_len] + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|i| { + a.at_poly_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) } } -fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( @@ -140,26 +141,26 @@ fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - data.copy_from_slice(a.at_poly(poly_idx, 0)); + data.copy_from_slice(a.at_poly(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); - (1..cols).for_each(|i| { - if i == cols - 1 && rem != log_base2k { + (1..limbs).for_each(|i| { + if i == limbs 
- 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } }) } -fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - let cols: usize = a.cols(); +fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { + let limbs: usize = a.limbs(); #[cfg(debug_assertions)] { assert!( @@ -168,23 +169,23 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - let prec: u32 = (log_base2k * cols) as u32; + let prec: u32 = (log_base2k * limbs) as u32; // 2^{log_base2k} let base = Float::with_val(prec, (1 << log_base2k) as f64); // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..cols).for_each(|i| { + (0..limbs).for_each(|i| { if i == 0 { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -192,61 +193,61 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ }); } -fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: 
usize, i: usize, value: i64, log_max: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!(i < a.n()); assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + limbs <= a.limbs(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", + limbs, + a.limbs() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0); + (0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[i] = value; + a.at_poly_mut(col_i, limbs - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs) .rev() .enumerate() .for_each(|(j, j_rev)| { - a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask; + a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } // Case where prec % k != 0. 
if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|j| { - a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|j| { + a.at_poly_mut(col_i, j)[i] <<= log_k_rem; }) } } -fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); - assert!(poly_idx < a.size()) + assert!(col_i < a.cols()) } let cols: usize = (log_k + log_base2k - 1) / log_base2k; let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = log_base2k - (log_k % log_base2k); - let slice_size: usize = a.n() * a.size(); + let slice_size: usize = a.n() * a.limbs(); (1..cols).for_each(|i| { let x = data[i * slice_size]; if i == cols - 1 && rem != log_base2k { @@ -275,13 +276,13 @@ mod tests { let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut have: Vec = vec![i64::default(); n]; have.iter_mut() .for_each(|x| *x = (source.next_i64() << 56) >> 56); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 10); let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }); } @@ -296,12 +297,12 @@ mod tests { let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut 
have: Vec = vec![i64::default(); n]; have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 64); let mut want = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }) } diff --git a/base2k/src/ffi/vec_znx_big.rs b/base2k/src/ffi/vec_znx_big.rs index e1222c3..8c06e90 100644 --- a/base2k/src/ffi/vec_znx_big.rs +++ b/base2k/src/ffi/vec_znx_big.rs @@ -8,17 +8,17 @@ pub struct vec_znx_big_t { pub type VEC_ZNX_BIG = vec_znx_big_t; unsafe extern "C" { - pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; + pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; } unsafe extern "C" { - pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; + pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; } unsafe extern "C" { - pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); + pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); } unsafe extern "C" { - pub fn vec_znx_big_add( + pub unsafe fn vec_znx_big_add( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -29,7 +29,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small( + pub unsafe fn vec_znx_big_add_small( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -41,7 +41,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small2( + pub unsafe fn vec_znx_big_add_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -54,7 +54,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub( + pub unsafe fn vec_znx_big_sub( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -65,7 +65,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn 
vec_znx_big_sub_small_b( + pub unsafe fn vec_znx_big_sub_small_b( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -77,7 +77,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small_a( + pub unsafe fn vec_znx_big_sub_small_a( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -89,7 +89,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small2( + pub unsafe fn vec_znx_big_sub_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -101,8 +101,13 @@ unsafe extern "C" { b_sl: u64, ); } + unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k( + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_normalize_base2k( module: *const MODULE, log2_base2k: u64, res: *mut i64, @@ -113,34 +118,9 @@ unsafe extern "C" { tmp_space: *mut u8, ); } -unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; -} unsafe extern "C" { - pub fn vec_znx_big_automorphism( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_rotate( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k( + pub unsafe fn vec_znx_big_range_normalize_base2k( module: *const MODULE, log2_base2k: u64, res: *mut i64, @@ -153,6 +133,29 @@ unsafe extern "C" { tmp_space: *mut u8, ); } + unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_automorphism( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + 
a_size: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_rotate( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); } diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index ba799d7..764a7fe 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,28 +1,4 @@ -#[derive(Copy, Clone)] -#[repr(C)] -pub struct LAYOUT{ - /// Ring degree. - n: usize, - /// Number of logical rows in the layout. - rows: usize, - /// Number of polynomials per row. - cols: usize, - /// Number of limbs per polynomial. - size: usize, - /// Whether limbs are interleaved inside a row. - /// - /// For example, for (rows, cols, size) = (2, 2, 3): - /// - /// - `true`: layout is ((a0, b0, a1, b1, a2, b2), (c0, d0, c1, d1, c2, d2)) - /// - `false`: layout is ((a0, a1, a2, b0, b1, b2), (c0, c1, c2, d0, d1, d2)) - interleaved : bool, -} - pub trait Infos { - - /// Returns the full layout. - fn layout(&self) -> LAYOUT; - /// Returns the ring degree of the polynomials. fn n(&self) -> usize; @@ -36,8 +12,8 @@ pub trait Infos { fn cols(&self) -> usize; /// Returns the number of limbs per polynomial. - fn size(&self) -> usize; + fn limbs(&self) -> usize; - /// Whether limbs are interleaved across rows. - fn interleaved(&self) -> bool; + /// Returns the total number of small polynomials. 
+ fn poly_count(&self) -> usize; } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index ec0d2b7..5144afd 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -27,7 +27,7 @@ pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; -pub fn is_aligned_custom(ptr: *const T, align: usize) -> bool { +fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } @@ -54,13 +54,10 @@ pub fn cast_mut(data: &[T]) -> &mut [V] { unsafe { std::slice::from_raw_parts_mut(ptr, len) } } -use std::alloc::{Layout, alloc}; -use std::ptr; - /// Allocates a block of bytes with a custom alignement. /// Alignement must be a power of two and size a multiple of the alignement. /// Allocated memory is initialized to zero. -pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { +fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { assert!( align.is_power_of_two(), "Alignment must be a power of two but is {}", @@ -74,8 +71,8 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); unsafe { - let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment"); - let ptr: *mut u8 = alloc(layout); + let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); + let ptr: *mut u8 = std::alloc::alloc(layout); if ptr.is_null() { panic!("Memory allocation failed"); } @@ -86,18 +83,11 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); // Init allocated memory to zero - ptr::write_bytes(ptr, 0, size); + std::ptr::write_bytes(ptr, 0, size); Vec::from_raw_parts(ptr, size, size) } } -/// Allocates a block of bytes aligned with [DEFAULTALIGN]. -/// Size must be amultiple of [DEFAULTALIGN]. -/// /// Allocated memory is initialized to zero. -pub fn alloc_aligned_u8(size: usize) -> Vec { - alloc_aligned_custom_u8(size, DEFAULTALIGN) -} - /// Allocates a block of T aligned with [DEFAULTALIGN]. 
/// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 8cbdbca..205cf62 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,5 +1,6 @@ use crate::GALOISGENERATOR; use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; #[derive(Copy, Clone)] #[repr(u8)] @@ -8,37 +9,50 @@ pub enum BACKEND { NTT120, } -pub struct Module { - pub ptr: *mut MODULE, - pub n: usize, - pub backend: BACKEND, +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; } -impl Module { +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + pub n: usize, + _marker: PhantomData, +} + +impl Module { // Instantiates a new module. 
- pub fn new(n: usize, module_type: BACKEND) -> Self { + pub fn new(n: usize) -> Self { unsafe { - let module_type_u32: u32; - match module_type { - BACKEND::FFT64 => module_type_u32 = 0, - BACKEND::NTT120 => module_type_u32 = 1, - } - let m: *mut module_info_t = new_module_info(n as u64, module_type_u32); + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); if m.is_null() { panic!("Failed to create module."); } Self { ptr: m, n: n, - backend: module_type, + _marker: PhantomData, } } } - pub fn backend(&self) -> BACKEND { - self.backend - } - pub fn n(&self) -> usize { self.n } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 064c1e2..db9a79b 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,16 +1,17 @@ -use crate::{Infos, Module, VecZnx}; +use crate::{Backend, Infos, Module, VecZnx}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { - /// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source); + /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. fn add_dist_f64>( &self, log_base2k: usize, a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -18,24 +19,35 @@ pub trait Sampling { ); /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. 
- fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64); + fn add_normal( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ); } -impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) { +impl Sampling for Module { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - let size: usize = a.n() * cols; - a.raw_mut()[..size] - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + (0..limbs).for_each(|j| { + a.at_poly_mut(col_i, j) + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + }) } fn add_dist_f64>( &self, log_base2k: usize, a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -50,28 +62,42 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << log_base2k_rem; - }); + a.at_poly_mut(col_i, a.limbs() - 1) + .iter_mut() + .for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); } else { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); + a.at_poly_mut(col_i, a.limbs() - 1) + .iter_mut() + .for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while 
dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); } } - fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + fn add_normal( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { self.add_dist_f64( log_base2k, a, + col_i, log_k, source, Normal::new(0.0, sigma).unwrap(), diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index bc37f86..e293668 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -1,6 +1,8 @@ +use std::marker::PhantomData; + use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, assert_alignement}; use crate::{Infos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; @@ -14,7 +16,7 @@ pub struct Scalar { pub ptr: *mut i64, } -impl Module { +impl Module { pub fn new_scalar(&self) -> Scalar { Scalar::new(self.n()) } @@ -117,9 +119,8 @@ impl Scalar { pub fn as_vec_znx(&self) -> VecZnx { VecZnx { n: self.n, - size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR cols: 1, - layout: LAYOUT::COL(1, 1), + limbs: 1, data: Vec::new(), ptr: self.ptr, } @@ -132,7 +133,7 @@ pub trait ScalarOps { fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar; fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar; } -impl ScalarOps for Module { +impl ScalarOps for Module { fn bytes_of_scalar(&self) -> usize { Scalar::bytes_of(self.n()) } @@ -147,17 +148,17 @@ impl ScalarOps for Module { } } -pub struct SvpPPol { +pub struct SvpPPol { pub n: usize, pub data: Vec, pub ptr: *mut u8, - pub backend: BACKEND, + _marker: PhantomData, } /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. 
-impl SvpPPol { - pub fn new(module: &Module) -> Self { +impl SvpPPol { + pub fn new(module: &Module) -> Self { module.new_svp_ppol() } @@ -166,11 +167,11 @@ impl SvpPPol { self.n } - pub fn bytes_of(module: &Module) -> usize { + pub fn bytes_of(module: &Module) -> usize { module.bytes_of_svp_ppol() } - pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol { + pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(bytes.as_ptr()); @@ -181,12 +182,12 @@ impl SvpPPol { n: module.n(), data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), - backend: module.backend(), + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol { + pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -196,7 +197,7 @@ impl SvpPPol { n: module.n(), data: Vec::new(), ptr: tmp_bytes.as_mut_ptr(), - backend: module.backend(), + _marker: PhantomData, } } @@ -206,9 +207,9 @@ impl SvpPPol { } } -pub trait SvpPPolOps { +pub trait SvpPPolOps { /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> SvpPPol; + fn new_svp_ppol(&self) -> SvpPPol; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. @@ -217,30 +218,30 @@ pub trait SvpPPolOps { /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. 
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); } -impl SvpPPolOps for Module { - fn new_svp_ppol(&self) -> SvpPPol { +impl SvpPPolOps for Module { + fn new_svp_ppol(&self) -> SvpPPol { let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); let ptr: *mut u8 = data.as_mut_ptr(); - SvpPPol { + SvpPPol:: { data: data, ptr: ptr, n: self.n(), - backend: self.backend(), + _marker: PhantomData, } } @@ -248,19 +249,19 @@ impl SvpPPolOps for Module { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { SvpPPol::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { SvpPPol::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { unsafe { svp::svp_apply_dft( self.ptr, diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 71a315e..aff1ce9 100644 --- 
a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,4 +1,4 @@ -use crate::LAYOUT; +use crate::Backend; use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; @@ -22,11 +22,11 @@ pub struct VecZnx { /// Polynomial degree. pub n: usize, - /// Number of limbs - pub size: usize, + /// The number of polynomials + pub cols: usize, - /// Layout - pub layout: LAYOUT, + /// The number of limbs per polynomial (a.k.a small polynomials). + pub limbs: usize, /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. pub data: Vec, @@ -35,58 +35,60 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, layout: LAYOUT, size: usize) -> usize { - n * layout.size() * size * 8 +pub fn bytes_of_vec_znx(n: usize, cols: usize, limbs: usize) -> usize { + n * cols * limbs * size_of::() } impl VecZnx { /// Returns a new struct implementing [VecZnx] with the provided data as backing array. /// - /// The struct will take ownership of buf[..[VecZnx::bytes_of]] + /// The struct will take ownership of buf[..[Self::bytes_of]] /// /// User must ensure that data is properly alligned and that - /// the size of data is equal to [VecZnx::bytes_of]. - pub fn from_bytes(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { + /// the limbs of data is equal to [Self::bytes_of]. 
+ pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, layout, size)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); assert_alignement(bytes.as_ptr()); } unsafe { let bytes_i64: &mut [i64] = cast_mut::(bytes); let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - VecZnx { + Self { n: n, - size: size, - layout: layout, + cols: cols, + limbs: limbs, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - pub fn from_bytes_borrow(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(n, layout, size)); + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); assert_alignement(bytes.as_ptr()); } - VecZnx { + Self { n: n, - size: size, - layout: layout, + cols: cols, + limbs: limbs, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } } - pub fn bytes_of(n: usize, layout: LAYOUT, size: usize) -> usize { - bytes_of_vec_znx(n, layout, size) + pub fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { + bytes_of_vec_znx(n, cols, limbs) } - pub fn copy_from(&mut self, a: &VecZnx) { + pub fn copy_from(&mut self, a: &Self) { copy_vec_znx_from(self, a); } @@ -94,15 +96,15 @@ impl VecZnx { self.data.len() == 0 } - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. + /// Total limbs is [Self::n()] * [Self::poly_count()]. pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.size) } + unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.poly_count()) } } /// Returns a reference to backend slice of the receiver. 
- /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. + /// Total size is [Self::n()] * [Self::poly_count()]. pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.size) } + unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.poly_count()) } } /// Returns a non-mutable pointer to the backedn slice of the receiver. @@ -115,76 +117,55 @@ impl VecZnx { self.ptr } - /// Returns a non-mutable pointer starting a the j-th column. - pub fn at_ptr(&self, i: usize) -> *const i64 { + /// Returns a non-mutable pointer starting a the (i, j)-th small poly. + pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { #[cfg(debug_assertions)] { - assert!(i < self.size); + assert!(i < self.cols()); + assert!(j < self.limbs()); } - let offset: usize = self.n * self.size * i; + let offset: usize = self.n * (j * self.cols() + i); self.ptr.wrapping_add(offset) } - /// Returns non-mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. - pub fn at(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i), self.n * self.size) } + /// Returns a non-mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb(&self, i: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } } - /// Returns a non-mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.size); - } - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns non-mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. 
+ /// Returns a non-mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_poly_ptr(i, j), self.n) } + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } } - /// Returns a mutable pointer starting a the j-th column. - pub fn at_mut_ptr(&self, i: usize) -> *mut i64 { + /// Returns a mutable pointer starting a the (i, j)-th small poly. + pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { #[cfg(debug_assertions)] { - assert!(i < self.size); + assert!(i < self.cols()); + assert!(j < self.limbs()); } - let offset: usize = self.n * self.size * i; + + let offset: usize = self.n * (j * self.cols() + i); self.ptr.wrapping_add(offset) } - /// Returns mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. - pub fn at_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n * self.size) } + /// Returns a mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } } - /// Returns a mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.size); - } - - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. + /// Returns a mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. 
pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - let ptr: *mut i64 = self.at_poly_mut_ptr(i, j); - unsafe { std::slice::from_raw_parts_mut(ptr, self.n) } + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } } pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.size * self.size) as u64, self.ptr) } + unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } } pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { @@ -195,48 +176,47 @@ impl VecZnx { rsh(log_base2k, self, k, carry) } - pub fn switch_degree(&self, a: &mut VecZnx) { + pub fn switch_degree(&self, a: &mut Self) { switch_degree(a, self) } - pub fn print(&self, poly: usize, size: usize, n: usize) { - (0..size).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) } } impl Infos for VecZnx { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } - fn size(&self) -> usize { - self.size + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ } - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of size of the [VecZnx]. - fn size(&self) -> usize { - self.size - } - - /// Returns the number of rows of the [VecZnx]. fn rows(&self) -> usize { 1 } + + fn cols(&self) -> usize { + self.cols + } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs + } } /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. +/// Panics if the cols do not match. 
pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { + assert_eq!(b.cols(), a.cols()); let data_a: &[i64] = a.raw(); let data_b: &mut [i64] = b.raw_mut(); let size = min(data_b.len(), data_a.len()); @@ -245,21 +225,20 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { impl VecZnx { /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - pub fn new(n: usize, size: usize, size: usize) -> Self { + pub fn new(n: usize, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); - assert!(size > 0); - assert!(size <= u8::MAX as usize); - assert!(size > 0); + assert!(cols > 0); + assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(n * size * size); + let mut data: Vec = alloc_aligned::(n * cols * limbs); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, - layout: LAYOUT::COL(1, size as u8), - size: size, + cols: cols, + limbs: limbs, data: data, ptr: ptr, } @@ -278,16 +257,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate((self.size() - k / log_base2k) * self.n() * self.size()); + .truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k)); } - self.size -= k / log_base2k; + self.limbs -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_mut(self.size() - 1) + self.at_limb_mut(self.limbs() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -305,31 +284,31 @@ pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { b.zero(); } - let size = min(a.size(), b.size()); + let limbs: usize = min(a.limbs(), b.limbs()); - (0..size).for_each(|i| { + (0..limbs).for_each(|i| { izip!( - a.at(i).iter().step_by(gap_in), - b.at_mut(i).iter_mut().step_by(gap_out) + a.at_limb(i).iter().step_by(gap_in), + b.at_limb_mut(i).iter_mut().step_by(gap_out) ) .for_each(|(x_in, x_out)| *x_out = *x_in); }); } -fn normalize_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +fn 
normalize_tmp_bytes(n: usize, limbs: usize) -> usize { + n * limbs * std::mem::size_of::() } fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let size: usize = a.size(); + let cols: usize = a.cols(); debug_assert!( - tmp_bytes.len() >= normalize_tmp_bytes(n, size), + tmp_bytes.len() >= normalize_tmp_bytes(n, cols), "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})", tmp_bytes.len(), n, - size, + cols, ); #[cfg(debug_assertions)] { @@ -340,45 +319,45 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.size()).rev().for_each(|i| { + (0..a.limbs()).rev().for_each(|i| { znx::znx_normalize( - (n * size) as u64, + (n * cols) as u64, log_base2k as u64, - a.at_mut_ptr(i), + a.at_mut_ptr(0, i), carry_i64.as_mut_ptr(), - a.at_mut_ptr(i), + a.at_mut_ptr(0, i), carry_i64.as_mut_ptr(), ) }); } } -pub fn rsh_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize { + n * limbs * std::mem::size_of::() } pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let size: usize = a.size(); + let limbs: usize = a.limbs(); #[cfg(debug_assertions)] { assert!( - tmp_bytes.len() >= rsh_tmp_bytes(n, size), + tmp_bytes.len() >= rsh_tmp_bytes(n, limbs), "invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})", tmp_bytes.len() >> 3, n, - size, + limbs, ); assert_alignement(tmp_bytes.as_ptr()); } - let size: usize = a.size(); + let limbs: usize = a.limbs(); let size_steps: usize = k / log_base2k; - a.raw_mut().rotate_right(n * size * size_steps); + a.raw_mut().rotate_right(n * limbs * size_steps); unsafe { - znx::znx_zero_i64_ref((n * size * size_steps) as u64, a.as_mut_ptr()); + znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr()); } let k_rem = k % log_base2k; @@ -387,13 +366,13 @@ pub fn 
rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let carry_i64: &mut [i64] = cast_mut(tmp_bytes); unsafe { - znx::znx_zero_i64_ref((n * size) as u64, carry_i64.as_mut_ptr()); + znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr()); } let log_base2k: usize = log_base2k; - (size_steps..size).for_each(|i| { - izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { + (size_steps..limbs).for_each(|i| { + izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k; *ci = get_base_k_carry(*xi, k_rem); *xi = (*xi - *ci) >> k_rem; @@ -412,14 +391,15 @@ pub trait VecZnxOps { /// /// # Arguments /// - /// * `size`: the number of size. - fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx; + /// * `cols`: the number of polynomials. + /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). + fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. - fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize; + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize; + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; /// c <- a + b. 
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); @@ -471,17 +451,17 @@ pub trait VecZnxOps { fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); } -impl VecZnxOps for Module { - fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx { - VecZnx::new(self.n(), size, size) +impl VecZnxOps for Module { + fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { + VecZnx::new(self.n(), cols, limbs) } - fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize { - bytes_of_vec_znx(self.n(), size, size) + fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { + bytes_of_vec_znx(self.n(), cols, limbs) } - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * size } + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } } // c <- a + b @@ -497,14 +477,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, c.as_mut_ptr(), - c.size() as u64, - (n * c.size()) as u64, + c.limbs() as u64, + (n * c.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -521,14 +501,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -546,14 +526,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, c.as_mut_ptr(), - c.size() as u64, - (n * c.size()) as u64, + c.limbs() as u64, + (n * c.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as 
u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -570,14 +550,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, ) } } @@ -594,14 +574,14 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, b.as_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -617,11 +597,11 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -636,11 +616,11 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -657,11 +637,11 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -677,11 +657,11 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - 
a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ) } } @@ -697,7 +677,7 @@ impl VecZnxOps for Module { /// /// # Panics /// - /// The method will panic if the argument `a` is greater than `a.size()`. + /// The method will panic if the argument `a` is greater than `a.limbs()`. fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -710,11 +690,11 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.size() as u64, - (n * b.size()) as u64, + b.limbs() as u64, + (n * b.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ); } } @@ -729,7 +709,7 @@ impl VecZnxOps for Module { /// /// # Panics /// - /// The method will panic if the argument `size` is greater than `self.size()`. + /// The method will panic if the argument `size` is greater than `self.limbs()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -741,11 +721,11 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, a.as_ptr(), - a.size() as u64, - (n * a.size()) as u64, + a.limbs() as u64, + (n * a.limbs()) as u64, ); } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 705a5ec..b19f126 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,24 +1,26 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use std::marker::PhantomData; -pub struct VecZnxBig { +pub struct VecZnxBig { pub data: Vec, pub ptr: *mut u8, pub n: usize, - pub size: usize, pub cols: usize, - pub layout: LAYOUT, - pub backend: BACKEND, + pub limbs: usize, + 
pub _marker: PhantomData, } -impl VecZnxBig { +impl VecZnxBig { /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -26,91 +28,84 @@ impl VecZnxBig { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft { + pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { + VecZnxDft:: { data: Vec::new(), ptr: self.ptr, n: self.n, - size: self.size, - layout: LAYOUT::COL, cols: self.cols, - backend: self.backend, + limbs: self.limbs, + _marker: self._marker, } } - pub fn backend(&self) -> BACKEND { - self.backend + /// Returns a non-mutable 
reference to the entire contiguous array of the [VecZnxDft]. + pub fn raw(&self) -> &[i64] { + let ptr: *const i64 = self.ptr as *const i64; + unsafe { &std::slice::from_raw_parts(ptr, self.n() * self.poly_count()) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + let raw: &[i64] = self.raw(); + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &raw[i * self.n() * self.cols()..i * self.n() * self.cols()+n])) } } -impl Infos for VecZnxBig { - /// Returns the base 2 logarithm of the [VecZnx] degree. +impl Infos for VecZnxBig { fn log_n(&self) -> usize { (usize::BITS - (self.n - 1).leading_zeros()) as _ } - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } - fn size(&self) -> usize { - self.size - } - - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols } - /// Returns the number of rows of the [VecZnx]. fn rows(&self) -> usize { 1 } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs + } } -pub trait VecZnxBigOps { +pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. 
/// @@ -118,12 +113,13 @@ pub trait VecZnxBigOps { /// /// # Arguments /// - /// * `cols`: the number of cols of the [VecZnxBig]. + /// * `cols`: the number of polynomials.. + /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -131,33 +127,44 @@ pub trait VecZnxBigOps { /// /// # Arguments /// - /// * `cols`: the number of cols of the [VecZnxBig]. + /// * `cols`: the number of polynomials.. + /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. 
- fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize; - /// b <- b - a - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] + /// + /// # Behavior + /// + /// [VecZnxBig] (3 cols and 4 limbs) + /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] + /// - + /// [VecZnx] (2 cols and 3 limbs) + /// [d0, e0] [d1, e1] [d2, e2] + /// = + /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); /// c <- b - a - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// c <- b + a - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// b <- b + a - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; /// b <- normalize(a) - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); + fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; @@ -165,100 +172,111 @@ pub trait VecZnxBigOps { &self, log_base2k: usize, res: &mut VecZnx, - a: &VecZnxBig, + a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, a_range_step: usize, tmp_bytes: &mut [u8], ); - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); + fn 
vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); } -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(size, cols)); +impl VecZnxBigOps for Module { + fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } + let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxBig { + VecZnxBig:: { data: data, ptr: ptr, n: self.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: self.backend(), + limbs: limbs, + _marker: PhantomData, } } - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, size, cols, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, cols, limbs, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes) + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes) } - fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size } + fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, limbs as u64) as usize * cols } } - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + /// [VecZnxBig] (3 cols and 4 limbs) + /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] + /// - + /// [VecZnx] (2 cols and 3 limbs) + /// [d0, e0] [d1, e1] [d2, e2] + /// = + /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] 
[a2-d2, b2-e2, c2] [a3, b3, c3] + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, ) } } - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, c.ptr as *mut vec_znx_big_t, - c.cols() as u64, + c.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, ) } } - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, c.ptr as *mut vec_znx_big_t, - c.cols() as u64, + c.poly_count() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, ) } } - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.as_ptr(), - a.cols() as u64, + a.poly_count() as u64, a.n() as u64, ) } @@ -268,12 +286,12 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { + fn vec_znx_big_normalize(&self, log_base2k: usize, 
b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_normalize_tmp_bytes(self), + tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", tmp_bytes.len(), - ::vec_znx_big_normalize_tmp_bytes(self) + Self::vec_znx_big_normalize_tmp_bytes(self) ); #[cfg(debug_assertions)] { @@ -284,10 +302,10 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, b.as_mut_ptr(), - b.cols() as u64, + b.limbs() as u64, b.n() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -301,17 +319,17 @@ impl VecZnxBigOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - a: &VecZnxBig, + a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, a_range_step: usize, tmp_bytes: &mut [u8], ) { debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_range_normalize_base2k_tmp_bytes(self), + tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", tmp_bytes.len(), - ::vec_znx_big_range_normalize_base2k_tmp_bytes(self) + Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) ); #[cfg(debug_assertions)] { @@ -322,7 +340,7 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, res.as_mut_ptr(), - res.cols() as u64, + res.limbs() as u64, res.n() as u64, a.ptr as *mut vec_znx_big_t, a_range_begin as u64, @@ -333,28 +351,28 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, ); } } - fn 
vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { + fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, a.ptr as *mut vec_znx_big_t, - a.cols() as u64, + a.poly_count() as u64, ); } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b512fd8..ec4067f 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,136 +1,139 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnxBig, assert_alignement}; use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; +use std::marker::PhantomData; -pub struct VecZnxDft { +pub struct VecZnxDft { pub data: Vec, pub ptr: *mut u8, pub n: usize, - pub size: usize, - pub layout: LAYOUT, pub cols: usize, - pub backend: BACKEND, + pub limbs: usize, + pub _marker: PhantomData, } -impl VecZnxDft { +impl VecZnxDft { + pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + Self { + data: data, + ptr: ptr, + n: module.n(), + limbs: limbs, + cols: cols, + _marker: PhantomData, + } + } /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. 
- pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); assert_alignement(bytes.as_ptr()) } unsafe { - VecZnxDft { + Self { data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } } - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { + pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); assert_alignement(bytes.as_ptr()); } - VecZnxDft { + Self { data: Vec::new(), ptr: bytes.as_mut_ptr(), n: module.n(), - size: size, - layout: LAYOUT::COL, cols: cols, - backend: module.backend, + limbs: limbs, + _marker: PhantomData, } } /// Cast a [VecZnxDft] into a [VecZnxBig]. /// The returned [VecZnxBig] shares the backing array /// with the original [VecZnxDft]. 
- pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig { + pub fn as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { data: Vec::new(), ptr: self.ptr, n: self.n, - layout: LAYOUT::COL, - size: self.size, cols: self.cols, - backend: self.backend, + limbs: self.limbs, + _marker: PhantomData, } } - pub fn backend(&self) -> BACKEND { - self.backend + pub fn raw(&self) -> &[f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { &std::slice::from_raw_parts(ptr, size) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + pub fn at(&self, col_i: usize) -> &[f64] { + &self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()] } - pub fn at(&self, module: &Module, col_i: usize) -> &[T] { - &self.raw::(module)[col_i * module.n()..(col_i + 1) * module.n()] + pub fn raw_mut(&mut self) -> &mut [f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } } - /// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. 
- pub fn raw_mut(&self, module: &Module) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } - } - - pub fn at_mut(&self, module: &Module, col_i: usize) -> &mut [T] { - &mut self.raw_mut::(module)[col_i * module.n()..(col_i + 1) * module.n()] + pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] { + let n: usize = self.n(); + let limbs:usize = self.limbs(); + &mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs] } } -impl Infos for VecZnxDft { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. +impl Infos for VecZnxDft { fn n(&self) -> usize { self.n } - fn layout(&self) -> LAYOUT { - self.layout + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + fn rows(&self) -> usize { + 1 } - /// Returns the number of cols of the [VecZnx]. fn cols(&self) -> usize { self.cols } - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.cols * self.limbs } } -pub trait VecZnxDftOps { +pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -143,7 +146,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. 
- fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -156,7 +159,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -167,61 +170,51 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. 
- fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); + fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); + fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; } -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_dft(size, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxDft { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: self.backend(), - } +impl VecZnxDftOps for Module { + fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft { + VecZnxDft::::new(&self, cols, limbs) } - fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, size, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes) + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes) } - fn 
bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size } + fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(self.ptr, limbs as u64) as usize * cols } } - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { unsafe { vec_znx_dft::vec_znx_idft_tmp_a( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, ) } } @@ -234,21 +227,21 @@ impl VecZnxDftOps for Module { /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { unsafe { vec_znx_dft::vec_znx_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, a.as_ptr(), - a.cols() as u64, - a.n() as u64, + a.limbs() as u64, + (a.n() * a.cols()) as u64, ) } } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. 
- fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert!( @@ -263,29 +256,29 @@ impl VecZnxDftOps for Module { vec_znx_dft::vec_znx_idft( self.ptr, b.ptr as *mut vec_znx_big_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { + fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { unsafe { vec_znx_dft::vec_znx_dft_automorphism( self.ptr, k, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, [0u8; 0].as_mut_ptr(), ); } } - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { + fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert!( @@ -301,9 +294,9 @@ impl VecZnxDftOps for Module { self.ptr, k, a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.poly_count() as u64, tmp_bytes.as_mut_ptr(), ); } @@ -321,41 +314,47 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; use itertools::izip; use sampling::source::{Source, new_seed}; #[test] fn test_automorphism_dft() { - let module: Module = Module::new(128, BACKEND::FFT64); + let module: Module = Module::::new(128); - let cols: usize = 2; + let limbs: usize = 2; let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, 
cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); + let mut a: VecZnx = module.new_vec_znx(1, limbs); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); let mut source: Source = Source::new(new_seed()); - module.fill_uniform(log_base2k, &mut a, cols, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); let p: i64 = -5; + + // a_dft <- DFT(a) module.vec_znx_dft(&mut a_dft, &a); + + // a_dft <- AUTO(a_dft) module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); + println!("123"); + // a <- AUTO(a) module.vec_znx_automorphism_inplace(p, &mut a); // b_dft <- DFT(AUTO(a)) module.vec_znx_dft(&mut b_dft, &a); - let a_f64: &[f64] = a_dft.raw(&module); - let b_f64: &[f64] = b_dft.raw(&module); + let a_f64: &[f64] = a_dft.raw(); + let b_f64: &[f64] = b_dft.raw(); izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); }); diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index b04232d..05dd027 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,7 +1,8 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. @@ -9,28 +10,23 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_ /// /// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. 
/// See the trait [VmpPMatOps] for additional information. -pub struct VmpPMat { +pub struct VmpPMat { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. ptr: *mut u8, - /// The size of the decomposition basis (i.e. nb. [VecZnxDft]). - rows: usize, - /// The size of each [VecZnxDft]. - cols: usize, - /// The ring degree of each [VecZnxDft]. + /// The ring degree of each polynomial. n: usize, - /// 1nd dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension). - /// A value greater than one enables to compute a sum of [VecZnx] x [VmpPMat]. - /// 2st dim: the number of stacked [VecZnxDft] (col-dimension). - /// A value greater than one enables to compute multiple [VecZnx] x [VmpPMat] in parallel. - layout: LAYOUT, - /// The backend fft or ntt. - backend: BACKEND, + /// Number of rows + rows: usize, + /// Number of cols + cols: usize, + /// The number of small polynomials + limbs: usize, + _marker: PhantomData, } -impl Infos for VmpPMat { - /// Returns the ring dimension of the [VmpPMat]. +impl Infos for VmpPMat { fn n(&self) -> usize { self.n } @@ -39,29 +35,39 @@ impl Infos for VmpPMat { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } - fn size(&self) -> usize { - self.size - } - - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat] fn rows(&self) -> usize { self.rows } - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of cols - /// of each [VecZnxDft]. - /// This method is equivalent to [Self::cols]. 
fn cols(&self) -> usize { self.cols } + + fn limbs(&self) -> usize { + self.limbs + } + + fn poly_count(&self) -> usize { + self.rows * self.cols * self.limbs + } } -impl VmpPMat { +impl VmpPMat { + + fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat { + let mut data: Vec = alloc_aligned::(module.bytes_of_vmp_pmat(rows, cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + VmpPMat:: { + data: data, + ptr: ptr, + n: module.n(), + rows: rows, + cols: cols, + limbs: limbs, + _marker: PhantomData, + } + } + pub fn as_ptr(&self) -> *const u8 { self.ptr } @@ -74,41 +80,31 @@ impl VmpPMat { self.data.len() == 0 } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. - pub fn raw(&self) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } + /// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat]. + pub fn raw(&self) -> &[f64] { + let ptr: *const f64 = self.ptr as *const f64; + let size: usize = self.n() * self.poly_count(); + unsafe { &std::slice::from_raw_parts(ptr, size) } } - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. 
- pub fn raw_mut(&self) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } + /// Returns a mutable reference of to the entire contiguous array of the [VmpPMat]. + pub fn raw_mut(&self) -> &mut [f64] { + let ptr: *mut f64 = self.ptr as *mut f64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } } /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. /// /// # Arguments /// /// * `row`: row index (i). /// * `col`: col index (j). - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = alloc_aligned(self.n); + pub fn at(&self, row: usize, col: usize) -> Vec { + let mut res: Vec = alloc_aligned(self.n); if self.n < 8 { - res.copy_from_slice( - &self.raw::()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)], - ); + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]); } else { (0..self.n >> 3).for_each(|blk| { res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); @@ -118,43 +114,37 @@ impl VmpPMat { res } - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[T] { + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); - let ncols: usize = self.cols(); - if col == (ncols - 1) && (ncols & 1 == 1) { - &self.raw::()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] 
+ let nsize: usize = self.limbs(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] } else { - &self.raw::()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] + &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] } } - - fn backend(&self) -> BACKEND { - self.backend - } } /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize; +pub trait VmpPMatOps { + fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// /// # Arguments /// /// * `rows`: number of rows (number of [VecZnxDft]). - /// * `cols`: number of cols (number of cols of each [VecZnxDft]). - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat; + /// * `size`: number of size (number of size of each [VecZnxDft]). + fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// /// # Arguments /// /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - /// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; /// Prepares a [VmpPMat] from a contiguous array of [i64]. /// The helper struct [Matrix3D] can be used to contruct and populate @@ -165,18 +155,7 @@ pub trait VmpPMatOps { /// * `b`: [VmpPMat] on which the values are encoded. 
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); - - /// Prepares a [VmpPMat] from a vector of [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]); + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// @@ -188,7 +167,7 @@ pub trait VmpPMatOps { /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. /// @@ -197,7 +176,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. /// @@ -208,7 +187,7 @@ pub trait VmpPMatOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. 
- fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); + fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. /// @@ -217,17 +196,17 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// /// # Arguments /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnx]. + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnx]. /// * `rows`: number of rows of the input [VmpPMat]. - /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the input [VmpPMat]. + fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// @@ -235,8 +214,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. 
/// @@ -253,7 +232,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. /// @@ -261,8 +240,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -279,17 +258,17 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// /// # Arguments /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnxDft]. + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnxDft]. /// * `rows`: number of rows of the input [VmpPMat]. 
- /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; + /// * `size`: number of size of the input [VmpPMat]. + fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -298,8 +277,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -316,7 +295,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -325,8 +304,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. 
/// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -343,7 +322,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -352,8 +331,8 @@ pub trait VmpPMatOps { /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. + /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -369,38 +348,29 @@ pub trait VmpPMatOps { /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); } -impl VmpPMatOps for Module { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size } +impl VmpPMatOps for Module { + + fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat { + VmpPMat::::new(self, rows, cols, limbs) } - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(size, rows, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VmpPMat { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - rows: rows, - backend: self.backend(), - } + fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize } } - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize } + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } } - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { - debug_assert_eq!(a.len(), b.n * b.rows * b.cols); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] { + assert_eq!(a.len(), b.n() * b.poly_count()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -409,40 +379,17 @@ impl VmpPMatOps for Module { b.as_mut_ptr() as *mut 
vmp_pmat_t, a.as_ptr(), b.rows() as u64, - b.cols() as u64, + (b.limbs()*b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) { - let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect(); + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] - { - debug_assert_eq!(a.len(), b.rows); - a.iter().for_each(|ai| { - debug_assert_eq!(ai.len(), b.n * b.cols); - }); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_dblptr( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - ptrs.as_ptr(), - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.cols() * self.n()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); + { + assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -452,16 +399,17 @@ impl VmpPMatOps for Module { a.as_ptr(), row_i as u64, b.rows() as u64, - b.cols() as u64, + (b.limbs()*b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); + assert_eq!(a.limbs(), b.limbs()); assert_eq!(a.cols(), b.cols()); } unsafe { @@ -471,16 +419,16 @@ impl VmpPMatOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.cols() as u64, + (a.limbs()*a.cols()) as u64, ); } } - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { + fn 
vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); + assert_eq!(a.limbs(), b.limbs()); } unsafe { vmp::vmp_prepare_row_dft( @@ -489,16 +437,16 @@ impl VmpPMatOps for Module { a.ptr as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, ); } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); + assert_eq!(a.limbs(), b.limbs()); } unsafe { vmp::vmp_extract_row_dft( @@ -507,48 +455,47 @@ impl VmpPMatOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.cols() as u64, + a.limbs() as u64, ); } } - fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize { + fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.ptr, - res_cols as u64, - a_cols as u64, + res_size as u64, + a_size as u64, gct_rows as u64, - gct_cols as u64, + gct_size as u64, ) as usize } } - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(a.size()*a.size(), b.size()); } unsafe { vmp::vmp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.as_ptr(), - a.cols() as u64, - (a.n()*a.size()) as u64, + a.limbs() as u64, + (a.n() 
* a.cols()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -557,32 +504,32 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.as_ptr(), - a.cols() as u64, - (a.n()*a.size()) as u64, + a.limbs() as u64, + (a.n() * a.limbs()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize { + fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.ptr, - res_cols as u64, - a_cols as u64, + res_size as u64, + a_size as u64, gct_rows as u64, - gct_cols as u64, + gct_size as u64, ) as usize } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -591,19 +538,19 @@ impl VmpPMatOps for Module { 
vmp::vmp_apply_dft_to_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.limbs() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -612,19 +559,19 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_to_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.limbs() as u64, a.ptr as *const vec_znx_dft_t, - a.cols() as u64, + a.limbs() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.cols() as u64, + b.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -633,12 +580,12 @@ impl VmpPMatOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, + b.limbs() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, - a.cols() as u64, + a.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -648,44 
+595,45 @@ impl VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned, + FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + alloc_aligned, }; use sampling::source::Source; #[test] fn vmp_prepare_row_dft() { - let module: Module = Module::new(32, crate::BACKEND::FFT64); + let module: Module = Module::::new(32); let vpmat_rows: usize = 4; - let vpmat_cols: usize = 5; + let vpmat_size: usize = 5; let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); + let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); + let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); - let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols)); + let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); for row_i in 0..vpmat_rows { let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut 
source); module.vec_znx_dft(&mut a_dft, &a); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); - assert_eq!(vmpmat_0.raw::(), vmpmat_1.raw::()); + assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); - assert_eq!(a_dft.raw::(&module), b_dft.raw::(&module)); + assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); - assert_eq!(a_big.raw::(&module), b_big.raw::(&module)); + assert_eq!(a_big.raw(), b_big.raw()); } module.free(); diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 9d1fe1a..73addb5 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,6 +1,6 @@ use crate::elem::{Elem, ElemCommon}; use crate::parameters::Parameters; -use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat}; +use base2k::{Infos, Layout, Module, VecZnx, VmpPMat}; pub struct Ciphertext(pub Elem); @@ -38,7 +38,7 @@ where self.elem().size() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.elem().layout() } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index e7e61c4..656cc3a 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,4 @@ -use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; +use base2k::{Infos, Layout, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; pub struct Elem { pub value: Vec, @@ -71,7 +71,7 @@ pub trait ElemCommon { fn elem(&self) -> &Elem; fn elem_mut(&mut self) -> &mut Elem; fn size(&self) -> usize; - fn layout(&self) -> LAYOUT; + fn layout(&self) -> Layout; fn rows(&self) -> usize; fn cols(&self) -> usize; fn log_base2k(&self) -> usize; @@ -102,7 +102,7 @@ impl ElemCommon 
for Elem { self.value.len() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.value[0].layout() } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 86f7e32..258756b 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,7 +1,7 @@ use crate::ciphertext::Ciphertext; use crate::elem::{Elem, ElemCommon, ElemVecZnx}; use crate::parameters::Parameters; -use base2k::{LAYOUT, Module, VecZnx}; +use base2k::{Layout, Module, VecZnx}; pub struct Plaintext(pub Elem); @@ -79,7 +79,7 @@ impl ElemCommon for Plaintext { self.elem().size() } - fn layout(&self) -> LAYOUT { + fn layout(&self) -> Layout { self.elem().layout() } From 82082db727c0794f49d92927ad553894b3ab41ad Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 26 Apr 2025 11:23:47 +0200 Subject: [PATCH 04/87] improved alligned vec allocation & fixed vec_znx calls, fixed auto dft test --- base2k/examples/rlwe_encrypt.rs | 4 -- base2k/src/lib.rs | 10 +++- base2k/src/vec_znx.rs | 54 ++++++++++---------- base2k/src/vec_znx_big.rs | 8 ++- base2k/src/vec_znx_dft.rs | 89 ++++++++++++++++++++++++--------- base2k/src/vmp.rs | 15 +++--- 6 files changed, 113 insertions(+), 67 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 1da44e9..5385a5b 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -34,8 +34,6 @@ fn main() { let mut a: VecZnx = module.new_vec_znx(1, limbs); module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); - - // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs()); @@ -93,8 +91,6 @@ fn main() { // res <- normalize(buf_big) module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); - - // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 5144afd..83c937a 100644 
--- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -106,6 +106,14 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { unsafe { Vec::from_raw_parts(ptr, len, cap) } } +// Allocates an aligned of size equal to the smallest power of two equal or greater to `size` that is +// at least as bit as DEFAULTALIGN / std::mem::size_of::(). pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::(size, DEFAULTALIGN) + alloc_aligned_custom::( + std::cmp::max( + size.next_power_of_two(), + DEFAULTALIGN / std::mem::size_of::(), + ), + DEFAULTALIGN, + ) } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index aff1ce9..a6d5858 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -478,13 +478,13 @@ impl VecZnxOps for Module { self.ptr, c.as_mut_ptr(), c.limbs() as u64, - (n * c.limbs()) as u64, + (n * c.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, b.as_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, ) } } @@ -502,13 +502,13 @@ impl VecZnxOps for Module { self.ptr, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, b.as_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, ) } } @@ -527,13 +527,13 @@ impl VecZnxOps for Module { self.ptr, c.as_mut_ptr(), c.limbs() as u64, - (n * c.limbs()) as u64, + (n * c.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, b.as_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, ) } } @@ -551,13 +551,13 @@ impl VecZnxOps for Module { self.ptr, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, b.as_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, ) } } @@ -575,13 +575,13 @@ impl 
VecZnxOps for Module { self.ptr, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, b.as_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ) } } @@ -598,10 +598,10 @@ impl VecZnxOps for Module { self.ptr, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ) } } @@ -617,10 +617,10 @@ impl VecZnxOps for Module { self.ptr, a.as_mut_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ) } } @@ -638,10 +638,10 @@ impl VecZnxOps for Module { k, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ) } } @@ -658,10 +658,10 @@ impl VecZnxOps for Module { k, a.as_mut_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ) } } @@ -691,10 +691,10 @@ impl VecZnxOps for Module { k, b.as_mut_ptr(), b.limbs() as u64, - (n * b.limbs()) as u64, + (n * b.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ); } } @@ -722,10 +722,10 @@ impl VecZnxOps for Module { k, a.as_mut_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, a.as_ptr(), a.limbs() as u64, - (n * a.limbs()) as u64, + (n * a.cols()) as u64, ); } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index b19f126..a7bdd59 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -73,7 +73,13 @@ impl VecZnxBig { // Prints the first `n` coefficients of each limb pub fn print(&self, n: usize) { let raw: &[i64] = self.raw(); - (0..self.limbs()).for_each(|i| 
println!("{}: {:?}", i, &raw[i * self.n() * self.cols()..i * self.n() * self.cols()+n])) + (0..self.limbs()).for_each(|i| { + println!( + "{}: {:?}", + i, + &raw[i * self.n() * self.cols()..i * self.n() * self.cols() + n] + ) + }) } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index ec4067f..61c2a85 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -82,14 +82,18 @@ impl VecZnxDft { } } - pub fn raw(&self) -> &[f64] { - let ptr: *mut f64 = self.ptr as *mut f64; - let size: usize = self.n() * self.poly_count(); - unsafe { &std::slice::from_raw_parts(ptr, size) } + /// Returns a non-mutable pointer to the backedn slice of the receiver. + pub fn as_ptr(&self) -> *const f64 { + self.ptr as *const f64 } - pub fn at(&self, col_i: usize) -> &[f64] { - &self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()] + /// Returns a mutable pointer to the backedn slice of the receiver. + pub fn as_mut_ptr(&mut self) -> *mut f64 { + self.ptr as *mut f64 + } + + pub fn raw(&self) -> &[f64] { + unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } pub fn raw_mut(&mut self) -> &mut [f64] { @@ -98,10 +102,54 @@ impl VecZnxDft { unsafe { std::slice::from_raw_parts_mut(ptr, size) } } - pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] { - let n: usize = self.n(); - let limbs:usize = self.limbs(); - &mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs] + pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset: usize = self.n * (j * self.cols() + i); + self.as_ptr().wrapping_add(offset) + } + + /// Returns a non-mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. 
+ pub fn at_limb(&self, i: usize) -> &[f64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a non-mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. + pub fn at_poly(&self, i: usize, j: usize) -> &[f64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } + } + + /// Returns a mutable pointer starting a the (i, j)-th small poly. + pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + + let offset: usize = self.n * (j * self.cols() + i); + self.as_mut_ptr().wrapping_add(offset) + } + + /// Returns a mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. 
+ pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } + } + + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } @@ -289,6 +337,7 @@ impl VecZnxDftOps for Module { ); assert_alignement(tmp_bytes.as_ptr()) } + println!("{}", a.poly_count()); unsafe { vec_znx_dft::vec_znx_dft_automorphism( self.ptr, @@ -303,12 +352,7 @@ impl VecZnxDftOps for Module { } fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize { - unsafe { - std::cmp::max( - vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize, - DEFAULTALIGN, - ) - } + unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize } } } @@ -316,11 +360,12 @@ impl VecZnxDftOps for Module { mod tests { use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; use itertools::izip; - use sampling::source::{Source, new_seed}; + use sampling::source::Source; #[test] fn test_automorphism_dft() { - let module: Module = Module::::new(128); + let n: usize = 8; + let module: Module = Module::::new(n); let limbs: usize = 2; let log_base2k: usize = 17; @@ -328,25 +373,19 @@ mod tests { let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); - let mut source: Source = Source::new(new_seed()); + let mut source: Source = Source::new([0u8; 32]); module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); let p: i64 = -5; - - // a_dft <- DFT(a) module.vec_znx_dft(&mut a_dft, &a); - - // a_dft <- AUTO(a_dft) module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); - println!("123"); - // a <- AUTO(a) module.vec_znx_automorphism_inplace(p, &mut a); diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 05dd027..f868a06 100644 --- a/base2k/src/vmp.rs +++ 
b/base2k/src/vmp.rs @@ -53,7 +53,6 @@ impl Infos for VmpPMat { } impl VmpPMat { - fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat { let mut data: Vec = alloc_aligned::(module.bytes_of_vmp_pmat(rows, cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); @@ -352,21 +351,19 @@ pub trait VmpPMatOps { } impl VmpPMatOps for Module { - fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat { VmpPMat::::new(self, rows, cols, limbs) } fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize } + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } } fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } + unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } } fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] { assert_eq!(a.len(), b.n() * b.poly_count()); @@ -379,7 +376,7 @@ impl VmpPMatOps for Module { b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), b.rows() as u64, - (b.limbs()*b.cols()) as u64, + (b.limbs() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } @@ -387,7 +384,7 @@ impl VmpPMatOps for Module { fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] - { + { assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert_alignement(tmp_bytes.as_ptr()); @@ -399,7 +396,7 @@ impl VmpPMatOps for Module { a.as_ptr(), row_i as u64, b.rows() as u64, - (b.limbs()*b.cols()) as u64, + (b.limbs() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } @@ -419,7 +416,7 @@ impl VmpPMatOps for Module { a.as_ptr() as *const vmp_pmat_t, 
row_i as u64, a.rows() as u64, - (a.limbs()*a.cols()) as u64, + (a.limbs() * a.cols()) as u64, ); } } From 5841845e22a7ba9ff95cfb0199f6b20aaadbecf0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 26 Apr 2025 11:29:58 +0200 Subject: [PATCH 05/87] uniformized data access between VecZnx, VecZnxBig & VecZnxDft --- base2k/src/vec_znx_big.rs | 113 ++++++++++++++++++++++++++++---------- base2k/src/vec_znx_dft.rs | 5 ++ 2 files changed, 90 insertions(+), 28 deletions(-) diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index a7bdd59..e1f656f 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -12,6 +12,25 @@ pub struct VecZnxBig { } impl VecZnxBig { + + pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } + let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_big(cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + Self { + data: data, + ptr: ptr, + n: module.n(), + cols: cols, + limbs: limbs, + _marker: PhantomData, + } + } + /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. @@ -64,22 +83,74 @@ impl VecZnxBig { } } - /// Returns a non-mutable reference to the entire contiguous array of the [VecZnxDft]. - pub fn raw(&self) -> &[i64] { - let ptr: *const i64 = self.ptr as *const i64; - unsafe { &std::slice::from_raw_parts(ptr, self.n() * self.poly_count()) } + /// Returns a non-mutable pointer to the backedn slice of the receiver. + pub fn as_ptr(&self) -> *const i64 { + self.ptr as *const i64 + } + + /// Returns a mutable pointer to the backedn slice of the receiver. 
+ pub fn as_mut_ptr(&mut self) -> *mut i64 { + self.ptr as *mut i64 + } + + pub fn raw(&self) -> &[i64] { + unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } + } + + pub fn raw_mut(&mut self) -> &mut [i64] { + let ptr: *mut i64 = self.ptr as *mut i64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } + } + + pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset: usize = self.n * (j * self.cols() + i); + self.as_ptr().wrapping_add(offset) + } + + /// Returns a non-mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb(&self, i: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a non-mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. + pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } + } + + /// Returns a mutable pointer starting a the (i, j)-th small poly. + pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + + let offset: usize = self.n * (j * self.cols() + i); + self.as_mut_ptr().wrapping_add(offset) + } + + /// Returns a mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. 
+ pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } } - // Prints the first `n` coefficients of each limb pub fn print(&self, n: usize) { - let raw: &[i64] = self.raw(); - (0..self.limbs()).for_each(|i| { - println!( - "{}: {:?}", - i, - &raw[i * self.n() * self.cols()..i * self.n() * self.cols() + n] - ) - }) + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } @@ -192,21 +263,7 @@ pub trait VecZnxBigOps { impl VecZnxBigOps for Module { fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - } - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxBig:: { - data: data, - ptr: ptr, - n: self.n(), - cols: cols, - limbs: limbs, - _marker: PhantomData, - } + VecZnxBig::new(self, cols, limbs) } fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 61c2a85..d984cdd 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -16,6 +16,11 @@ pub struct VecZnxDft { impl VecZnxDft { pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); Self { From 6532f30f66796e3accd2ec576dc1130fae82e040 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 26 Apr 2025 12:34:42 +0200 Subject: [PATCH 06/87] centralized sensitive code into VecZnxLayout --- base2k/examples/rlwe_encrypt.rs | 2 +- base2k/examples/vector_matrix_product.rs | 4 +- base2k/src/commons.rs | 70 +++++++ base2k/src/encoding.rs | 4 +- base2k/src/infos.rs | 19 -- base2k/src/lib.rs | 4 +- 
base2k/src/sampling.rs | 2 +- base2k/src/svp.rs | 2 +- base2k/src/vec_znx.rs | 255 +++++++++-------------- base2k/src/vec_znx_big.rs | 81 ++----- base2k/src/vec_znx_dft.rs | 91 ++------ base2k/src/vmp.rs | 6 +- 12 files changed, 218 insertions(+), 322 deletions(-) create mode 100644 base2k/src/commons.rs delete mode 100644 base2k/src/infos.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 5385a5b..cb9dfa8 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, alloc_aligned, + VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned, }; use itertools::izip; use sampling::source::Source; diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 4e8b97e..8d4a33d 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, - alloc_aligned, + Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, + VmpPMatOps, alloc_aligned, }; fn main() { diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs new file mode 100644 index 0000000..ef7a410 --- /dev/null +++ b/base2k/src/commons.rs @@ -0,0 +1,70 @@ +pub trait Infos { + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize; + + /// Returns the number of rows. + fn rows(&self) -> usize; + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize; + + /// Returns the number of limbs per polynomial. 
+ fn limbs(&self) -> usize; + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize; +} + +pub trait VecZnxLayout: Infos { + type Scalar; + + fn as_ptr(&self) -> *const Self::Scalar; + fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + + fn raw(&self) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } + } + + fn raw_mut(&mut self) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } + } + + fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset = self.n() * (j * self.cols() + i); + unsafe { self.as_ptr().add(offset) } + } + + fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset = self.n() * (j * self.cols() + i); + unsafe { self.as_mut_ptr().add(offset) } + } + + fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } + } + + fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } + } + + fn at_limb(&self, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) } + } + + fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } + } +} diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index d4085cb..5944f3c 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx}; +use crate::{Infos, VecZnx, VecZnxLayout}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,7 
+262,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, Infos, VecZnx}; + use crate::{Encoding, Infos, VecZnx, VecZnxLayout}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs deleted file mode 100644 index 764a7fe..0000000 --- a/base2k/src/infos.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub trait Infos { - /// Returns the ring degree of the polynomials. - fn n(&self) -> usize; - - /// Returns the base two logarithm of the ring dimension of the polynomials. - fn log_n(&self) -> usize; - - /// Returns the number of rows. - fn rows(&self) -> usize; - - /// Returns the number of polynomials in each row. - fn cols(&self) -> usize; - - /// Returns the number of limbs per polynomial. - fn limbs(&self) -> usize; - - /// Returns the total number of small polynomials. - fn poly_count(&self) -> usize; -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 83c937a..4d54ca0 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -1,8 +1,8 @@ +pub mod commons; pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; -pub mod infos; pub mod module; pub mod sampling; pub mod stats; @@ -12,8 +12,8 @@ pub mod vec_znx_big; pub mod vec_znx_dft; pub mod vmp; +pub use commons::*; pub use encoding::*; -pub use infos::*; pub use module::*; pub use sampling::*; #[allow(unused_imports)] diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index db9a79b..b60e420 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Infos, Module, VecZnx}; +use crate::{Backend, Infos, Module, VecZnx, VecZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index e293668..ba375c7 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ 
-2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, VecZnxLayout, assert_alignement}; use crate::{Infos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index a6d5858..9b47eae 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -2,7 +2,7 @@ use crate::Backend; use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Infos, Module}; +use crate::{Infos, Module, VecZnxLayout}; use crate::{alloc_aligned, assert_alignement}; use itertools::izip; use std::cmp::min; @@ -35,157 +35,6 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, cols: usize, limbs: usize) -> usize { - n * cols * limbs * size_of::() -} - -impl VecZnx { - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. 
- pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } - - pub fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { - bytes_of_vec_znx(n, cols, limbs) - } - - pub fn copy_from(&mut self, a: &Self) { - copy_vec_znx_from(self, a); - } - - pub fn borrowing(&self) -> bool { - self.data.len() == 0 - } - - /// Total limbs is [Self::n()] * [Self::poly_count()]. - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.poly_count()) } - } - - /// Returns a reference to backend slice of the receiver. - /// Total size is [Self::n()] * [Self::poly_count()]. - pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.poly_count()) } - } - - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut i64 { - self.ptr - } - - /// Returns a non-mutable pointer starting a the (i, j)-th small poly. 
- pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.ptr.wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.ptr.wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. 
- pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - - pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } - } - - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { - normalize(log_base2k, self, carry) - } - - pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - - pub fn switch_degree(&self, a: &mut Self) { - switch_degree(a, self) - } - - // Prints the first `n` coefficients of each limb - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) - } -} - impl Infos for VecZnx { fn n(&self) -> usize { self.n @@ -212,6 +61,18 @@ impl Infos for VecZnx { } } +impl VecZnxLayout for VecZnx { + type Scalar = i64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -271,6 +132,83 @@ impl VecZnx { .for_each(|x: &mut i64| *x &= mask) } } + + fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { + n * cols * limbs * size_of::() + } + + /// Returns a new struct implementing [VecZnx] with the provided data as backing array. + /// + /// The struct will take ownership of buf[..[Self::bytes_of]] + /// + /// User must ensure that data is properly alligned and that + /// the limbs of data is equal to [Self::bytes_of]. 
+ pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + unsafe { + let bytes_i64: &mut [i64] = cast_mut::(bytes); + let ptr: *mut i64 = bytes_i64.as_mut_ptr(); + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), + ptr: ptr, + } + } + } + + pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::new(), + ptr: bytes.as_mut_ptr() as *mut i64, + } + } + + pub fn copy_from(&mut self, a: &Self) { + copy_vec_znx_from(self, a); + } + + pub fn borrowing(&self) -> bool { + self.data.len() == 0 + } + + pub fn zero(&mut self) { + unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } + } + + pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { + normalize(log_base2k, self, carry) + } + + pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { + rsh(log_base2k, self, k, carry) + } + + pub fn switch_degree(&self, a: &mut Self) { + switch_degree(a, self) + } + + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) + } } pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { @@ -395,6 +333,9 @@ pub trait VecZnxOps { /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). 
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx; + /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; @@ -457,7 +398,15 @@ impl VecZnxOps for Module { } fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - bytes_of_vec_znx(self.n(), cols, limbs) + VecZnx::bytes_of(self.n(), cols, limbs) + } + + fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes(self.n(), cols, limbs, bytes) + } + + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index e1f656f..ac02aab 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { @@ -12,7 +12,6 @@ pub struct VecZnxBig { } impl VecZnxBig { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { @@ -83,72 +82,6 @@ impl VecZnxBig { } } - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const i64 { - self.ptr as *const i64 - } - - /// Returns a mutable pointer to the backedn slice of the receiver. 
- pub fn as_mut_ptr(&mut self) -> *mut i64 { - self.ptr as *mut i64 - } - - pub fn raw(&self) -> &[i64] { - unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } - } - - pub fn raw_mut(&mut self) -> &mut [i64] { - let ptr: *mut i64 = self.ptr as *mut i64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.as_ptr().wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.as_mut_ptr().wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. 
- pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - pub fn print(&self, n: usize) { (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } @@ -180,6 +113,18 @@ impl Infos for VecZnxBig { } } +impl VecZnxLayout for VecZnxBig { + type Scalar = i64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr as *const Self::Scalar + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr as *mut Self::Scalar + } +} + pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index d984cdd..6d3c6f6 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,8 +1,8 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnxBig, assert_alignement}; -use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; +use crate::{Backend, FFT64, Infos, Module, VecZnxBig, VecZnxLayout, assert_alignement}; +use crate::{VecZnx, alloc_aligned}; use std::marker::PhantomData; pub struct VecZnxDft { @@ -32,6 +32,11 @@ impl VecZnxDft { _marker: PhantomData, } } + + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } + } + /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. @@ -87,72 +92,6 @@ impl VecZnxDft { } } - /// Returns a non-mutable pointer to the backedn slice of the receiver. 
- pub fn as_ptr(&self) -> *const f64 { - self.ptr as *const f64 - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut f64 { - self.ptr as *mut f64 - } - - pub fn raw(&self) -> &[f64] { - unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } - } - - pub fn raw_mut(&mut self) -> &mut [f64] { - let ptr: *mut f64 = self.ptr as *mut f64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.as_ptr().wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[f64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[f64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.as_mut_ptr().wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. 
- pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - pub fn print(&self, n: usize) { (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } @@ -184,6 +123,18 @@ impl Infos for VecZnxDft { } } +impl VecZnxLayout for VecZnxDft { + type Scalar = f64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr as *const Self::Scalar + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr as *mut Self::Scalar + } +} + pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; @@ -257,7 +208,7 @@ impl VecZnxDftOps for Module { } fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, limbs as u64) as usize * cols } + VecZnxDft::bytes_of(&self, cols, limbs) } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { @@ -363,7 +314,7 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index f868a06..f2af561 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -592,8 +592,8 @@ impl 
VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, - alloc_aligned, + FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, + VmpPMatOps, alloc_aligned, }; use sampling::source::Source; From 54148acf6b24b8306ee8c0319c134f7d276c9995 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 26 Apr 2025 13:19:22 +0200 Subject: [PATCH 07/87] more refactoring --- base2k/examples/rlwe_encrypt.rs | 6 +- base2k/examples/vector_matrix_product.rs | 10 +- base2k/src/commons.rs | 25 ++++- base2k/src/encoding.rs | 10 +- base2k/src/lib.rs | 8 +- base2k/src/{vmp.rs => mat_znx_dft.rs} | 88 ++++++++-------- base2k/src/sampling.rs | 2 +- base2k/src/{svp.rs => scalar_znx_dft.rs} | 38 +++---- base2k/src/stats.rs | 2 +- base2k/src/vec_znx.rs | 123 ++++++++++++----------- base2k/src/vec_znx_big.rs | 50 +++++---- base2k/src/vec_znx_dft.rs | 44 ++++---- rlwe/benches/gadget_product.rs | 10 +- rlwe/examples/encryption.rs | 4 +- rlwe/src/automorphism.rs | 14 +-- rlwe/src/ciphertext.rs | 12 +-- rlwe/src/decryptor.rs | 10 +- rlwe/src/elem.rs | 8 +- rlwe/src/encryptor.rs | 26 ++--- rlwe/src/gadget_product.rs | 16 +-- rlwe/src/key_generator.rs | 6 +- rlwe/src/key_switching.rs | 12 +-- rlwe/src/keys.rs | 8 +- rlwe/src/rgsw_product.rs | 12 +-- rlwe/src/trace.rs | 6 +- 25 files changed, 294 insertions(+), 256 deletions(-) rename base2k/src/{vmp.rs => mat_znx_dft.rs} (88%) rename base2k/src/{svp.rs => scalar_znx_dft.rs} (87%) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index cb9dfa8..3d53141 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned, 
+ Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxLayout, alloc_aligned, }; use itertools::izip; use sampling::source::Source; @@ -25,7 +25,7 @@ fn main() { s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: SvpPPol = module.new_svp_ppol(); + let mut s_ppol: ScalarZnxDft = module.new_svp_ppol(); // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 8d4a33d..0120f61 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, - VmpPMatOps, alloc_aligned, + Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + ZnxInfos, ZnxLayout, alloc_aligned, }; fn main() { @@ -31,16 +31,16 @@ fn main() { a.print(n); println!(); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows_mat, 1, limbs_mat); + let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); (0..a.limbs()).for_each(|row_i| { let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); tmp.at_limb_mut(row_i)[1] = 1 as i64; - module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); + module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); }); let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); - module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); + module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index ef7a410..290599d 100644 --- a/base2k/src/commons.rs +++ 
b/base2k/src/commons.rs @@ -1,4 +1,6 @@ -pub trait Infos { +use crate::{Backend, Module}; + +pub trait ZnxInfos { /// Returns the ring degree of the polynomials. fn n(&self) -> usize; @@ -18,20 +20,34 @@ pub trait Infos { fn poly_count(&self) -> usize; } -pub trait VecZnxLayout: Infos { +pub trait ZnxBase { + type Scalar; + fn new(module: &Module, cols: usize, limbs: usize) -> Self; + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize; +} + +pub trait ZnxLayout: ZnxInfos { type Scalar; + /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar; + + /// Returns a mutable pointer to the underlying coefficients array. fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + /// Returns a non-mutable reference to the entire underlying coefficient array. fn raw(&self) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } + /// Returns a mutable reference to the entire underlying coefficient array. fn raw_mut(&mut self) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } } + /// Returns a non-mutable pointer starting at the (i, j)-th small polynomial. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { @@ -42,6 +58,7 @@ pub trait VecZnxLayout: Infos { unsafe { self.as_ptr().add(offset) } } + /// Returns a mutable pointer starting at the (i, j)-th small polynomial. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { @@ -52,18 +69,22 @@ pub trait VecZnxLayout: Infos { unsafe { self.as_mut_ptr().add(offset) } } + /// Returns non-mutable reference to the (i, j)-th small polynomial. 
fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } } + /// Returns mutable reference to the (i, j)-th small polynomial. fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } + /// Returns non-mutable reference to the i-th limb. fn at_limb(&self, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) } } + /// Returns mutable reference to the i-th limb. fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } } diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 5944f3c..6034b95 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx, VecZnxLayout}; +use crate::{VecZnx, ZnxInfos, ZnxLayout}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,17 +262,18 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, Infos, VecZnx, VecZnxLayout}; + use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout}; use itertools::izip; use sampling::source::Source; #[test] fn test_set_get_i64_lo_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; let cols: usize = 5; let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let mut a: VecZnx = VecZnx::new(&module, 2, cols); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -290,10 +291,11 @@ mod tests { #[test] fn test_set_get_i64_hi_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; let 
cols: usize = 5; let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let mut a: VecZnx = VecZnx::new(&module, 2, cols); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 4d54ca0..40df3bb 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,26 +3,26 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; +pub mod mat_znx_dft; pub mod module; pub mod sampling; +pub mod scalar_znx_dft; pub mod stats; -pub mod svp; pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_dft; -pub mod vmp; pub use commons::*; pub use encoding::*; +pub use mat_znx_dft::*; pub use module::*; pub use sampling::*; +pub use scalar_znx_dft::*; #[allow(unused_imports)] pub use stats::*; -pub use svp::*; pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; -pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; diff --git a/base2k/src/vmp.rs b/base2k/src/mat_znx_dft.rs similarity index 88% rename from base2k/src/vmp.rs rename to base2k/src/mat_znx_dft.rs index f2af561..9466696 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -10,7 +10,7 @@ use std::marker::PhantomData; /// /// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and 
a [VmpPMat]. /// See the trait [VmpPMatOps] for additional information. -pub struct VmpPMat { +pub struct MatZnxDft { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. @@ -26,7 +26,7 @@ pub struct VmpPMat { _marker: PhantomData, } -impl Infos for VmpPMat { +impl ZnxInfos for MatZnxDft { fn n(&self) -> usize { self.n } @@ -52,11 +52,11 @@ impl Infos for VmpPMat { } } -impl VmpPMat { - fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(module.bytes_of_vmp_pmat(rows, cols, limbs)); +impl MatZnxDft { + fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { + let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); - VmpPMat:: { + MatZnxDft:: { data: data, ptr: ptr, n: module.n(), @@ -126,8 +126,8 @@ impl VmpPMat { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize; +pub trait MatZnxDftOps { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// @@ -135,7 +135,7 @@ pub trait VmpPMatOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat; + fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// @@ -154,7 +154,7 @@ pub trait VmpPMatOps { /// * `b`: [VmpPMat] on which the values are encoded. /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. 
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// @@ -166,7 +166,7 @@ pub trait VmpPMatOps { /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. /// @@ -175,7 +175,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. /// @@ -186,7 +186,7 @@ pub trait VmpPMatOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. /// @@ -195,7 +195,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. 
- fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// @@ -231,7 +231,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. /// @@ -257,7 +257,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// @@ -294,7 +294,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. 
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -321,7 +321,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -347,15 +347,15 @@ pub trait VmpPMatOps { /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); } -impl VmpPMatOps for Module { - fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat { - VmpPMat::::new(self, rows, cols, limbs) +impl MatZnxDftOps for Module { + fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols, limbs) } - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize { unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } } @@ -363,7 +363,7 @@ impl VmpPMatOps for Module { unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } } - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.n() * b.poly_count()); @@ -382,7 +382,7 @@ impl VmpPMatOps for Module { } } - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); @@ -402,7 +402,7 @@ impl VmpPMatOps for Module { } } - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -421,7 +421,7 @@ impl VmpPMatOps for Module { } } - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ 
-439,7 +439,7 @@ impl VmpPMatOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -469,7 +469,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -491,7 +491,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -525,7 +525,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -546,7 +546,13 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft_add( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + tmp_bytes: &mut [u8], + ) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -567,7 +573,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: 
&mut [u8]) { + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); #[cfg(debug_assertions)] { @@ -592,8 +598,8 @@ impl VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, - VmpPMatOps, alloc_aligned, + FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + ZnxLayout, alloc_aligned, }; use sampling::source::Source; @@ -608,8 +614,8 @@ mod tests { let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); + let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); + let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); @@ -619,15 +625,15 @@ mod tests { module.vec_znx_dft(&mut a_dft, &a); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); - // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) + // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); - // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) + // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); assert_eq!(a_dft.raw(), b_dft.raw()); - // Checks that a_big = extract(prepare_dft(vmp_pmat, 
a_dft), b_big) + // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); assert_eq!(a_big.raw(), b_big.raw()); diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b60e420..c415b80 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Infos, Module, VecZnx, VecZnxLayout}; +use crate::{Backend, Module, VecZnx, ZnxInfos, ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; diff --git a/base2k/src/svp.rs b/base2k/src/scalar_znx_dft.rs similarity index 87% rename from base2k/src/svp.rs rename to base2k/src/scalar_znx_dft.rs index ba375c7..7457ca2 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,9 +2,9 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, VecZnxLayout, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement}; -use crate::{Infos, alloc_aligned, cast_mut}; +use crate::{ZnxInfos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -148,7 +148,7 @@ impl ScalarOps for Module { } } -pub struct SvpPPol { +pub struct ScalarZnxDft { pub n: usize, pub data: Vec, pub ptr: *mut u8, @@ -157,7 +157,7 @@ pub struct SvpPPol { /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl SvpPPol { +impl ScalarZnxDft { pub fn new(module: &Module) -> Self { module.new_svp_ppol() } @@ -207,9 +207,9 @@ impl SvpPPol { } } -pub trait SvpPPolOps { +pub trait ScalarZnxDftOps { /// Allocates a new [SvpPPol]. 
- fn new_svp_ppol(&self) -> SvpPPol; + fn new_svp_ppol(&self) -> ScalarZnxDft; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. @@ -218,26 +218,26 @@ pub trait SvpPPolOps { /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); + fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. 
- fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx); } -impl SvpPPolOps for Module { - fn new_svp_ppol(&self) -> SvpPPol { +impl ScalarZnxDftOps for Module { + fn new_svp_ppol(&self) -> ScalarZnxDft { let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); let ptr: *mut u8 = data.as_mut_ptr(); - SvpPPol:: { + ScalarZnxDft:: { data: data, ptr: ptr, n: self.n(), @@ -249,19 +249,19 @@ impl SvpPPolOps for Module { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes(self, bytes) + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes_borrow(self, tmp_bytes) + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { + fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar) { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx) { unsafe { svp::svp_apply_dft( self.ptr, diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index f72ebaa..44e441f 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,4 +1,4 @@ -use crate::{Encoding, Infos, VecZnx}; +use crate::{Encoding, VecZnx, ZnxInfos}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 9b47eae..89173f0 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,8 +1,9 @@ use crate::Backend; +use crate::ZnxBase; use 
crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Infos, Module, VecZnxLayout}; +use crate::{Module, ZnxInfos, ZnxLayout}; use crate::{alloc_aligned, assert_alignement}; use itertools::izip; use std::cmp::min; @@ -35,7 +36,7 @@ pub struct VecZnx { pub ptr: *mut i64, } -impl Infos for VecZnx { +impl ZnxInfos for VecZnx { fn n(&self) -> usize { self.n } @@ -61,7 +62,7 @@ impl Infos for VecZnx { } } -impl VecZnxLayout for VecZnx { +impl ZnxLayout for VecZnx { type Scalar = i64; fn as_ptr(&self) -> *const Self::Scalar { @@ -84,9 +85,12 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { data_b[..size].copy_from_slice(&data_a[..size]) } -impl VecZnx { +impl ZnxBase for VecZnx { + type Scalar = i64; + /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - pub fn new(n: usize, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, limbs: usize) -> Self { + let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(n > 0); @@ -94,7 +98,7 @@ impl VecZnx { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(n * cols * limbs); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, @@ -105,6 +109,57 @@ impl VecZnx { } } + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + module.n() * cols * limbs * size_of::() + } + + /// Returns a new struct implementing [VecZnx] with the provided data as backing array. + /// + /// The struct will take ownership of buf[..[Self::bytes_of]] + /// + /// User must ensure that data is properly alligned and that + /// the limbs of data is equal to [Self::bytes_of]. 
+ fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + let n: usize = module.n(); + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + unsafe { + let bytes_i64: &mut [i64] = cast_mut::(bytes); + let ptr: *mut i64 = bytes_i64.as_mut_ptr(); + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), + ptr: ptr, + } + } + } + + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(module, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + Self { + n: module.n(), + cols: cols, + limbs: limbs, + data: Vec::new(), + ptr: bytes.as_mut_ptr() as *mut i64, + } + } +} + +impl VecZnx { /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -133,54 +188,6 @@ impl VecZnx { } } - fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { - n * cols * limbs * size_of::() - } - - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. 
- pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } - pub fn copy_from(&mut self, a: &Self) { copy_vec_znx_from(self, a); } @@ -394,19 +401,19 @@ pub trait VecZnxOps { impl VecZnxOps for Module { fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { - VecZnx::new(self.n(), cols, limbs) + VecZnx::new(self, cols, limbs) } fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - VecZnx::bytes_of(self.n(), cols, limbs) + VecZnx::bytes_of(self, cols, limbs) } fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self.n(), cols, limbs, bytes) + VecZnx::from_bytes(self, cols, limbs, bytes) } fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes) + VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index ac02aab..7a8cc48 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, 
vec_znx_big_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { @@ -10,16 +10,17 @@ pub struct VecZnxBig { pub limbs: usize, pub _marker: PhantomData, } +impl ZnxBase for VecZnxBig { + type Scalar = u8; -impl VecZnxBig { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_big(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, @@ -30,15 +31,19 @@ impl VecZnxBig { } } + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols } + } + /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. 
- pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -53,12 +58,12 @@ impl VecZnxBig { } } - pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { @@ -70,24 +75,9 @@ impl VecZnxBig { _marker: PhantomData, } } - - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft:: { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - cols: self.cols, - limbs: self.limbs, - _marker: self._marker, - } - } - - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); - } } -impl Infos for VecZnxBig { +impl ZnxInfos for VecZnxBig { fn log_n(&self) -> usize { (usize::BITS - (self.n - 1).leading_zeros()) as _ } @@ -113,7 +103,7 @@ impl Infos for VecZnxBig { } } -impl VecZnxLayout for VecZnxBig { +impl ZnxLayout for VecZnxBig { type Scalar = i64; fn as_ptr(&self) -> *const Self::Scalar { @@ -125,6 +115,12 @@ impl VecZnxLayout for VecZnxBig { } } +impl VecZnxBig { + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + } +} + pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. 
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; @@ -220,7 +216,7 @@ impl VecZnxBigOps for Module { } fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, limbs as u64) as usize * cols } + VecZnxBig::bytes_of(self, cols, limbs) } /// [VecZnxBig] (3 cols and 4 limbs) diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 6d3c6f6..7724710 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnxBig, VecZnxLayout, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; use crate::{VecZnx, alloc_aligned}; use std::marker::PhantomData; @@ -14,15 +14,17 @@ pub struct VecZnxDft { pub _marker: PhantomData, } -impl VecZnxDft { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { +impl ZnxBase for VecZnxDft { + type Scalar = u8; + + fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); + let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, limbs)); + let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, @@ -33,19 +35,19 @@ impl VecZnxDft { } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } } /// Returns a new [VecZnxDft] with the provided data as backing array. 
/// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()) } unsafe { @@ -60,12 +62,12 @@ impl VecZnxDft { } } - pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { @@ -77,12 +79,14 @@ impl VecZnxDft { _marker: PhantomData, } } +} +impl VecZnxDft { /// Cast a [VecZnxDft] into a [VecZnxBig]. /// The returned [VecZnxBig] shares the backing array /// with the original [VecZnxDft]. 
- pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig:: { + pub fn as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { data: Vec::new(), ptr: self.ptr, n: self.n, @@ -91,13 +95,9 @@ impl VecZnxDft { _marker: PhantomData, } } - - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); - } } -impl Infos for VecZnxDft { +impl ZnxInfos for VecZnxDft { fn n(&self) -> usize { self.n } @@ -123,7 +123,7 @@ impl Infos for VecZnxDft { } } -impl VecZnxLayout for VecZnxDft { +impl ZnxLayout for VecZnxDft { type Scalar = f64; fn as_ptr(&self) -> *const Self::Scalar { @@ -135,6 +135,12 @@ impl VecZnxLayout for VecZnxDft { } } +impl VecZnxDft { + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + } +} + pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; @@ -314,7 +320,7 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned}; use itertools::izip; use sampling::source::Source; diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index fdd2240..14bb06d 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,4 +1,4 @@ -use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8}; +use base2k::{BACKEND, Module, Sampling, ScalarZnxDftOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, alloc_aligned_u8}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, @@ -16,7 +16,7 @@ fn bench_gadget_product_inplace(c: &mut 
Criterion) { res_dft_0: &'a mut VecZnxDft, res_dft_1: &'a mut VecZnxDft, a: &'a VecZnx, - b: &'a Ciphertext, + b: &'a Ciphertext, b_cols: usize, tmp_bytes: &'a mut [u8], ) -> Box { @@ -69,13 +69,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let mut source_xe: Source = Source::new([4; 32]); let mut source_xa: Source = Source::new([5; 32]); - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk0_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( + let mut gadget_ct: Ciphertext = new_gadget_ciphertext( params.module(), params.log_base2k(), params.cols_q(), diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index b9d66cd..20a0603 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -1,4 +1,4 @@ -use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned}; +use base2k::{Encoding, ScalarZnxDftOps, VecZnx, alloc_aligned}; use rlwe::{ ciphertext::Ciphertext, elem::ElemCommon, @@ -51,7 +51,7 @@ fn main() { let mut source_xe: Source = Source::new([1; 32]); let mut source_xa: Source = Source::new([2; 32]); - let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); params.encrypt_rlwe_sk( diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index 5e5b48a..d76e356 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -7,15 +7,15 @@ use crate::{ parameters::Parameters, }; use base2k::{ - Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, 
VecZnxDftOps, VecZnxOps, VmpPMat, - VmpPMatOps, assert_alignement, + Module, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, + MatZnxDftOps, assert_alignement, }; use sampling::source::Source; use std::collections::HashMap; /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} pub struct AutomorphismKey { - pub value: Ciphertext, + pub value: Ciphertext, pub p: i64, } @@ -106,12 +106,12 @@ impl AutomorphismKey { let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); - let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); + let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); let mut keys: Vec = Vec::new(); p.iter().for_each(|pi| { - let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); + let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); let p_inv: i64 = module.galois_element_inv(*pi); @@ -223,7 +223,7 @@ mod test { parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext, }; - use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, alloc_aligned}; use sampling::source::{Source, new_seed}; #[test] @@ -267,7 +267,7 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); let p: i64 = -5; diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 73addb5..bcffeec 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,6 +1,6 @@ use crate::elem::{Elem, 
ElemCommon}; use crate::parameters::Parameters; -use base2k::{Infos, Layout, Module, VecZnx, VmpPMat}; +use base2k::{ZnxInfos, Layout, Module, VecZnx, MatZnxDft}; pub struct Ciphertext(pub Elem); @@ -12,7 +12,7 @@ impl Parameters { impl ElemCommon for Ciphertext where - T: Infos, + T: ZnxInfos, { fn n(&self) -> usize { self.elem().n() @@ -78,16 +78,16 @@ pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext::::new(module, log_base2k, log_q, rows) } -pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { +pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); elem.log_q = log_q; Ciphertext(elem) } -pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { +pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); elem.log_q = log_q; Ciphertext(elem) } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 6eeea27..4c1fb7e 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -5,16 +5,16 @@ use crate::{ parameters::Parameters, plaintext::Plaintext, }; -use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; +use base2k::{Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; use std::cmp::min; pub struct Decryptor { - sk: SvpPPol, + sk: ScalarZnxDft, } impl Decryptor { pub fn new(params: &Parameters, sk: &SecretKey) -> Self { - let mut sk_svp_ppol: SvpPPol 
= params.module().new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); sk.prepare(params.module(), &mut sk_svp_ppol); Self { sk: sk_svp_ppol } } @@ -32,12 +32,12 @@ impl Parameters { ) } - pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext, sk: &SvpPPol, tmp_bytes: &mut [u8]) { + pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } } -pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, sk: &SvpPPol, tmp_bytes: &mut [u8]) { +pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { let cols: usize = a.cols(); assert!( diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 656cc3a..c6fe59f 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,4 @@ -use base2k::{Infos, Layout, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; +use base2k::{ZnxInfos, Layout, Module, VecZnx, VecZnxOps, MatZnxDft, MatZnxDftOps}; pub struct Elem { pub value: Vec, @@ -81,7 +81,7 @@ pub trait ElemCommon { fn at_mut(&mut self, i: usize) -> &mut T; } -impl ElemCommon for Elem { +impl ElemCommon for Elem { fn n(&self) -> usize { self.value[0].n() } @@ -152,11 +152,11 @@ impl Elem { } } -impl Elem { +impl Elem { pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self { assert!(rows > 0); assert!(cols > 0); - let mut value: Vec = Vec::new(); + let mut value: Vec = Vec::new(); (0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols))); Self { value: value, diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index bdb383c..7354a0f 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -5,8 +5,8 @@ use crate::parameters::Parameters; use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ - Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, 
VecZnxDftOps, VecZnxOps, VmpPMat, - VmpPMatOps, + ZnxInfos, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, + MatZnxDftOps, }; use sampling::source::{Source, new_seed}; @@ -19,7 +19,7 @@ impl Parameters { &self, ct: &mut Ciphertext, pt: Option<&Plaintext>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], @@ -38,7 +38,7 @@ impl Parameters { } pub struct EncryptorSk { - sk: SvpPPol, + sk: ScalarZnxDft, source_xa: Source, source_xe: Source, initialized: bool, @@ -47,7 +47,7 @@ pub struct EncryptorSk { impl EncryptorSk { pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self { - let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); let mut initialized: bool = false; if let Some(sk) = sk { sk.prepare(params.module(), &mut sk_svp_ppol); @@ -114,7 +114,7 @@ pub fn encrypt_rlwe_sk( module: &Module, ct: &mut Elem, pt: Option<&VecZnx>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -127,7 +127,7 @@ fn encrypt_rlwe_sk_core( module: &Module, ct: &mut Elem, pt: Option<&VecZnx>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -217,9 +217,9 @@ pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usiz pub fn encrypt_grlwe_sk( module: &Module, - ct: &mut Ciphertext, + ct: &mut Ciphertext, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -258,9 +258,9 @@ pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut Ciphertext, + ct: &mut Ciphertext, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -302,10 +302,10 @@ pub fn encrypt_rgsw_sk( fn 
encrypt_grlwe_sk_core( module: &Module, log_base2k: usize, - mut ct: [&mut VmpPMat; 2], + mut ct: [&mut MatZnxDft; 2], log_q: usize, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index bbf9642..9315cd8 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -1,5 +1,5 @@ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps}; use std::cmp::min; pub fn gadget_product_core_tmp_bytes( @@ -34,7 +34,7 @@ pub fn gadget_product_core( res_dft_0: &mut VecZnxDft, res_dft_1: &mut VecZnxDft, a: &VecZnx, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -61,7 +61,7 @@ pub fn gadget_product_big( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, tmp_bytes: &mut [u8], ) { let cols: usize = min(c.cols(), a.cols()); @@ -94,7 +94,7 @@ pub fn gadget_product( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, tmp_bytes: &mut [u8], ) { let cols: usize = min(c.cols(), a.cols()); @@ -130,7 +130,7 @@ mod test { plaintext::Plaintext, }; use base2k::{ - BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, + BACKEND, ZnxInfos, Sampling, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, alloc_aligned_u8, }; use sampling::source::{Source, new_seed}; @@ -175,16 +175,16 @@ mod test { // Two secret keys let mut sk0: SecretKey = SecretKey::new(params.module()); sk0.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk0_svp_ppol: 
base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); let mut sk1: SecretKey = SecretKey::new(params.module()); sk1.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); // The gadget ciphertext - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( + let mut gadget_ct: Ciphertext = new_gadget_ciphertext( params.module(), log_base2k, params.cols_qp(), diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 4f62a2c..88a2331 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1,7 +1,7 @@ use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}; use crate::keys::{PublicKey, SecretKey, SwitchingKey}; use crate::parameters::Parameters; -use base2k::{Module, SvpPPol}; +use base2k::{Module, ScalarZnxDft}; use sampling::source::Source; pub struct KeyGenerator {} @@ -16,7 +16,7 @@ impl KeyGenerator { pub fn gen_public_key_thread_safe( &self, params: &Parameters, - sk_ppol: &SvpPPol, + sk_ppol: &ScalarZnxDft, source: &mut Source, tmp_bytes: &mut [u8], ) -> PublicKey { @@ -43,7 +43,7 @@ pub fn gen_switching_key( module: &Module, swk: &mut SwitchingKey, sk_in: &SecretKey, - sk_out: &SvpPPol, + sk_out: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs index 4e0001a..e73c7f9 100644 --- a/rlwe/src/key_switching.rs +++ b/rlwe/src/key_switching.rs @@ -1,6 +1,6 @@ use crate::ciphertext::Ciphertext; use crate::elem::ElemCommon; -use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; use std::cmp::min; pub fn key_switch_tmp_bytes(module: 
&Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { @@ -16,7 +16,7 @@ pub fn key_switch_rlwe( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -26,7 +26,7 @@ pub fn key_switch_rlwe( pub fn key_switch_rlwe_inplace( module: &Module, a: &mut Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -37,7 +37,7 @@ fn key_switch_rlwe_core( module: &Module, c: *mut Ciphertext, a: *const Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -74,6 +74,6 @@ fn key_switch_rlwe_core( module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes); } -pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} +pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} -pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} +pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index da7c412..6017159 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,7 +1,7 @@ use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; use crate::elem::{Elem, ElemCommon}; use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; -use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat}; +use base2k::{Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, MatZnxDft}; use sampling::source::Source; pub struct SecretKey(pub Scalar); @@ -19,7 +19,7 @@ impl SecretKey { self.0.fill_ternary_hw(hw, source); } - pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) { + pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { module.svp_prepare(sk_ppol, &self.0) } } @@ -34,7 +34,7 @@ impl PublicKey { pub fn gen_thread_safe( &mut self, module: 
&Module, - sk: &SvpPPol, + sk: &ScalarZnxDft, xe: f64, xa_source: &mut Source, xe_source: &mut Source, @@ -57,7 +57,7 @@ impl PublicKey { } } -pub struct SwitchingKey(pub Ciphertext); +pub struct SwitchingKey(pub Ciphertext); impl SwitchingKey { pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey { diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index dc42602..1f76166 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -1,5 +1,5 @@ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; use std::cmp::min; impl Parameters { @@ -26,7 +26,7 @@ pub fn rgsw_product( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -69,7 +69,7 @@ pub fn rgsw_product( pub fn rgsw_product_inplace( module: &Module, a: &mut Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -120,7 +120,7 @@ mod test { plaintext::Plaintext, rgsw_product::rgsw_product_inplace, }; - use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, MatZnxDft, alloc_aligned}; use sampling::source::{Source, new_seed}; #[test] @@ -164,10 +164,10 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); - let mut ct_rgsw: Ciphertext = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); + let 
mut ct_rgsw: Ciphertext = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); let k: i64 = 3; diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs index 9e7feb8..005c497 100644 --- a/rlwe/src/trace.rs +++ b/rlwe/src/trace.rs @@ -1,5 +1,5 @@ use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDftOps, assert_alignement}; use std::collections::HashMap; pub fn trace_galois_elements(module: &Module) -> Vec { @@ -115,7 +115,7 @@ mod test { parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, plaintext::Plaintext, }; - use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, alloc_aligned}; use sampling::source::{Source, new_seed}; use std::collections::HashMap; @@ -160,7 +160,7 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); let gal_els: Vec = trace_galois_elements(module); From 78b6e9544d37e63d7edd7e1064abfddc67d838be Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 28 Apr 2025 11:17:16 +0530 Subject: [PATCH 08/87] Updated all crates to edition 2024 and set workspace resolver to "3". `gen` is reserved keyword in 2024. 
So modigied `galois_element` function in base2k/src/module.rs for compat --- Cargo.toml | 2 +- base2k/Cargo.toml | 2 +- base2k/build.rs | 5 +++-- base2k/src/module.rs | 19 +++++++++++-------- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a17e5f7..b99028c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] members = ["base2k", "rlwe", "sampling", "utils"] - +resolver = "3" [workspace.dependencies] rug = "1.27" diff --git a/base2k/Cargo.toml b/base2k/Cargo.toml index 2ebb8db..089cbde 100644 --- a/base2k/Cargo.toml +++ b/base2k/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "base2k" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] rug = {workspace = true} diff --git a/base2k/build.rs b/base2k/build.rs index 4ddb96c..f592b15 100644 --- a/base2k/build.rs +++ b/base2k/build.rs @@ -3,10 +3,11 @@ use std::path::absolute; fn main() { println!( "cargo:rustc-link-search=native={}", - absolute("./spqlios-arithmetic/build/spqlios") + absolute("spqlios-arithmetic/build/spqlios") .unwrap() .to_str() .unwrap() ); - println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios" + println!("cargo:rustc-link-lib=static=spqlios"); + // println!("cargo:rustc-link-lib=dylib=spqlios") } diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 205cf62..c1799be 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -65,21 +65,24 @@ impl Module { (self.n() << 1) as _ } - // Returns GALOISGENERATOR^|gen| * sign(gen) - pub fn galois_element(&self, gen: i64) -> i64 { - if gen == 0 { + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { return 1; } - ((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum() + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() } // Returns gen^-1 
- pub fn galois_element_inv(&self, gen: i64) -> i64 { - if gen == 0 { + pub fn galois_element_inv(&self, generator: i64) -> i64 { + if generator == 0 { panic!("cannot invert 0") } - ((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64) - * gen.signum() + ((mod_exp_u64( + generator.abs() as u64, + (self.cyclotomic_order() - 1) as usize, + ) & (self.cyclotomic_order() - 1)) as i64) + * generator.signum() } pub fn free(self) { From 39bbe5b91704469c7956ed3adaad0dd8294113ed Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 28 Apr 2025 09:02:42 +0200 Subject: [PATCH 09/87] added tests for sampling (and indirectly stats) --- base2k/.vscode/settings.json | 8 +++ base2k/src/encoding.rs | 12 ++--- base2k/src/sampling.rs | 102 ++++++++++++++++++++++++++++------- base2k/src/stats.rs | 13 +++-- base2k/src/vec_znx_big.rs | 2 +- 5 files changed, 107 insertions(+), 30 deletions(-) create mode 100644 base2k/.vscode/settings.json diff --git a/base2k/.vscode/settings.json b/base2k/.vscode/settings.json new file mode 100644 index 0000000..eecbcdc --- /dev/null +++ b/base2k/.vscode/settings.json @@ -0,0 +1,8 @@ +{ + "github.copilot.enable": { + "*": false, + "plaintext": false, + "markdown": false, + "scminput": false + } +} \ No newline at end of file diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 6034b95..980dab4 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -271,9 +271,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, cols); + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -293,9 +293,9 @@ 
mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, cols); + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index c415b80..80d174c 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Module, VecZnx, ZnxInfos, ZnxLayout}; +use crate::{Backend, Module, VecZnx, ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; @@ -59,28 +59,25 @@ impl Sampling for Module { (bound.log2().ceil() as i64) ); + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_poly_mut(col_i, a.limbs() - 1) - .iter_mut() - .for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << log_base2k_rem; - }); + a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); } else { - a.at_poly_mut(col_i, a.limbs() - 1) - .iter_mut() - .for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); + a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); } } @@ -105,3 +102,70 @@ impl Sampling 
for Module { ); } } + +#[cfg(test)] +mod tests { + use super::Sampling; + use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout}; + use sampling::source::Source; + + #[test] + fn fill_uniform() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let log_base2k: usize = 17; + let limbs: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let one_12_sqrt: f64 = 0.28867513459481287; + (0..cols).for_each(|col_i| { + let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..limbs).for_each(|limb_i| { + assert_eq!(a.at_poly(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k); + assert!( + (std - one_12_sqrt).abs() < 0.01, + "std={} ~!= {}", + std, + one_12_sqrt + ); + } + }) + }); + } + + #[test] + fn add_normal() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let log_base2k: usize = 17; + let log_k: usize = 2 * 17; + let limbs: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << log_k as u64) as f64; + (0..cols).for_each(|col_i| { + let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..limbs).for_each(|limb_i| { + assert_eq!(a.at_poly(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k) * k_f64; + assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); + } + }) + }); + } +} diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 44e441f..7fcf7c3 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -3,11 +3,16 @@ use rug::Float; use rug::float::Round; use 
rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; -impl VecZnx { - pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 { - let prec: u32 = (self.cols() * log_base2k) as u32; +pub trait Stats { + /// Returns the standard devaition of the i-th polynomial. + fn std(&self, col_i: usize, log_base2k: usize) -> f64; +} + +impl Stats for VecZnx { + fn std(&self, col_i: usize, log_base2k: usize) -> f64 { + let prec: u32 = (self.limbs() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); - self.decode_vec_float(poly_idx, log_base2k, &mut data); + self.decode_vec_float(col_i, log_base2k, &mut data); // std = sqrt(sum((xi - avg)^2) / n) let mut avg: Float = Float::with_val(prec, 0); data.iter().for_each(|x| { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 7a8cc48..8c67a8d 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { From 2f9a1cf6d9e493606dd80c4d10f735d1821cdda0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 28 Apr 2025 10:33:15 +0200 Subject: [PATCH 10/87] refactoring of vec_znx --- base2k/examples/rlwe_encrypt.rs | 6 +- base2k/examples/vector_matrix_product.rs | 2 +- base2k/src/commons.rs | 227 ++++++- base2k/src/encoding.rs | 74 +-- base2k/src/lib.rs | 2 + base2k/src/mat_znx_dft.rs | 98 ++- base2k/src/sampling.rs | 22 +- base2k/src/scalar_znx_dft.rs | 2 +- base2k/src/stats.rs | 2 +- base2k/src/vec_znx.rs | 556 +--------------- base2k/src/vec_znx_big.rs | 86 ++- base2k/src/vec_znx_dft.rs | 84 ++- base2k/src/vec_znx_ops.rs | 795 +++++++++++++++++++++++ 13 files changed, 1218 insertions(+), 738 
deletions(-) create mode 100644 base2k/src/vec_znx_ops.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 3d53141..3661f0d 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -35,7 +35,7 @@ fn main() { module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); // Applies buf_dft <- s * a module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); @@ -93,9 +93,9 @@ fn main() { // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have); - let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; + let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 0120f61..96a0df7 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -33,7 +33,7 @@ fn main() { let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); - (0..a.limbs()).for_each(|row_i| { + (0..a.size()).for_each(|row_i| { let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); tmp.at_limb_mut(row_i)[1] = 1 as i64; module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 290599d..1d7a0c9 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -1,11 +1,15 @@ -use crate::{Backend, Module}; +use crate::{Backend, Module, assert_alignement, cast_mut}; +use itertools::izip; +use std::cmp::{max, min}; pub trait ZnxInfos { /// Returns the ring degree of the polynomials. 
fn n(&self) -> usize; /// Returns the base two logarithm of the ring dimension of the polynomials. - fn log_n(&self) -> usize; + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } /// Returns the number of rows. fn rows(&self) -> usize; @@ -13,21 +17,28 @@ pub trait ZnxInfos { /// Returns the number of polynomials in each row. fn cols(&self) -> usize; - /// Returns the number of limbs per polynomial. - fn limbs(&self) -> usize; + /// Returns the number of size per polynomial. + fn size(&self) -> usize; /// Returns the total number of small polynomials. - fn poly_count(&self) -> usize; + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } + + /// Returns the slice size, which is the offset between + /// two size of the same column. + fn sl(&self) -> usize { + self.n() * self.cols() + } } pub trait ZnxBase { type Scalar; - fn new(module: &Module, cols: usize, limbs: usize) -> Self; - fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize; + fn new(module: &Module, cols: usize, size: usize) -> Self; + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize; } - pub trait ZnxLayout: ZnxInfos { type Scalar; @@ -52,7 +63,7 @@ pub trait ZnxLayout: ZnxInfos { #[cfg(debug_assertions)] { assert!(i < self.cols()); - assert!(j < self.limbs()); + assert!(j < self.size()); } let offset = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } @@ -63,7 +74,7 @@ pub trait ZnxLayout: ZnxInfos { #[cfg(debug_assertions)] { assert!(i < self.cols()); - assert!(j < self.limbs()); + assert!(j < self.size()); } let offset = self.n() * (j * 
self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } @@ -89,3 +100,195 @@ pub trait ZnxLayout: ZnxInfos { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } } } + +use std::convert::TryFrom; +use std::num::TryFromIntError; +use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; +pub trait IntegerType: + Copy + + std::fmt::Debug + + Default + + PartialEq + + PartialOrd + + Add + + Sub + + Mul + + Div + + Neg + + Shr + + Shl + + AddAssign + + TryFrom +{ + const BITS: u32; +} + +impl IntegerType for i64 { + const BITS: u32 = 64; +} + +impl IntegerType for i128 { + const BITS: u32 = 128; +} + +pub trait ZnxBasics: ZnxLayout +where + Self: Sized, + Self::Scalar: IntegerType, +{ + fn zero(&mut self) { + unsafe { + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::()); + } + } + + fn zero_at(&mut self, i: usize, j: usize) { + unsafe { + std::ptr::write_bytes( + self.at_mut_ptr(i, j), + 0, + self.n() * size_of::(), + ); + } + } + + fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { + rsh(log_base2k, self, k, carry) + } +} + +pub fn rsh(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8]) +where + V::Scalar: IntegerType, +{ + let n: usize = a.n(); + let size: usize = a.size(); + let cols: usize = a.cols(); + + #[cfg(debug_assertions)] + { + assert!( + tmp_bytes.len() >= rsh_tmp_bytes::(n, cols), + "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", + tmp_bytes.len() / size_of::(), + n, + size, + ); + assert_alignement(tmp_bytes.as_ptr()); + } + + let size: usize = a.size(); + let steps: usize = k / log_base2k; + + a.raw_mut().rotate_right(n * steps * cols); + (0..cols).for_each(|i| { + (0..steps).for_each(|j| { + a.zero_at(i, j); + }) + }); + + let k_rem: usize = k % log_base2k; + + if k_rem != 0 { + let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); + + unsafe { + std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); + } + + let log_base2k_t: 
V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); + let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); + let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); + + (steps..size).for_each(|i| { + izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { + *xi += *ci << log_base2k_t; + *ci = get_base_k_carry(*xi, shift); + *xi = (*xi - *ci) >> k_rem_t; + }); + }) + } +} + +#[inline(always)] +fn get_base_k_carry(x: T, shift: T) -> T { + (x << shift) >> shift +} + +pub fn rsh_tmp_bytes(n: usize, cols: usize) -> usize { + n * cols * std::mem::size_of::() +} + +pub fn switch_degree(b: &mut T, a: &T) +where + ::Scalar: IntegerType, +{ + let (n_in, n_out) = (a.n(), b.n()); + let (gap_in, gap_out): (usize, usize); + + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + b.zero(); + } + + let size: usize = min(a.size(), b.size()); + + (0..size).for_each(|i| { + izip!( + a.at_limb(i).iter().step_by(gap_in), + b.at_limb_mut(i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); +} + +pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) +where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_ne!(a.as_ptr(), b.as_ptr()); + assert_ne!(b.as_ptr(), c.as_ptr()); + assert_ne!(a.as_ptr(), c.as_ptr()); + } + + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + + let min_ab_cols: usize = min(a_cols, b_cols); + let max_ab_cols: usize = max(a_cols, b_cols); + + // Copies shared shared cols between (c, max(a, b)) + if a_cols != b_cols { + let mut x: &T = a; + if a_cols < b_cols { + x = b; + } + + let min_size = min(c.size(), x.size()); + (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { + (0..min_size).for_each(|j| { + c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j)); + if NEGATE { + c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + 
(min_size..c.size()).for_each(|j| { + c.zero_at(i, j); + }); + }); + } + + // Zeroes the cols of c > max(a, b). + if c_cols > max_ab_cols { + (max_ab_cols..c_cols).for_each(|i| { + (0..c.size()).for_each(|j| { + c.zero_at(i, j); + }) + }); + } +} diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 980dab4..8c41381 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -81,15 +81,15 @@ impl Encoding for VecZnx { } fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( - limbs <= a.limbs(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", - limbs, - a.limbs() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); assert!(col_i < a.cols()); assert!(data.len() <= a.n()) @@ -99,7 +99,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, let log_k_rem: usize = log_base2k - (log_k % log_base2k); // Zeroes coefficients of the i-th column - (0..a.limbs()).for_each(|i| unsafe { + (0..a.size()).for_each(|i| unsafe { znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); }); @@ -107,11 +107,11 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // values on the last limb. // Else we decompose values base2k. 
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(i, i_rev)| { @@ -122,8 +122,8 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // Case where self.prec % self.k != 0. if log_k_rem != log_base2k { - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|i| { + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|i| { a.at_poly_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); @@ -132,7 +132,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, } fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( @@ -145,8 +145,8 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat } data.copy_from_slice(a.at_poly(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); - (1..limbs).for_each(|i| { - if i == limbs - 1 && rem != log_base2k { + (1..size).for_each(|i| { + if i == size - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); @@ -160,7 +160,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat } fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, 
data: &mut [Float]) { - let limbs: usize = a.limbs(); + let size: usize = a.size(); #[cfg(debug_assertions)] { assert!( @@ -172,20 +172,20 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo assert!(col_i < a.cols()); } - let prec: u32 = (log_base2k * limbs) as u32; + let prec: u32 = (log_base2k * size) as u32; // 2^{log_base2k} let base = Float::with_val(prec, (1 << log_base2k) as f64); // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..limbs).for_each(|i| { + (0..size).for_each(|i| { if i == 0 { - izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -194,32 +194,32 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo } fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!(i < a.n()); assert!( - limbs <= a.limbs(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", - limbs, - a.limbs() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); assert!(col_i < a.cols()); } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); + (0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. 
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, limbs - 1)[i] = value; + a.at_poly_mut(col_i, size - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(j, j_rev)| { @@ -229,8 +229,8 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz // Case where prec % k != 0. if log_k_rem != log_base2k { - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|j| { + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|j| { a.at_poly_mut(col_i, j)[i] <<= log_k_rem; }) } @@ -247,7 +247,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = log_base2k - (log_k % log_base2k); - let slice_size: usize = a.n() * a.limbs(); + let slice_size: usize = a.n() * a.size(); (1..cols).for_each(|i| { let x = data[i * slice_size]; if i == cols - 1 && rem != log_base2k { @@ -271,9 +271,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, limbs); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -293,9 +293,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = 
VecZnx::new(&module, 2, limbs); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 40df3bb..3c48319 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -11,6 +11,7 @@ pub mod stats; pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_dft; +pub mod vec_znx_ops; pub use commons::*; pub use encoding::*; @@ -23,6 +24,7 @@ pub use stats::*; pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; +pub use vec_znx_ops::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 9466696..b40ed71 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -22,7 +22,7 @@ pub struct MatZnxDft { /// Number of cols cols: usize, /// The number of small polynomials - limbs: usize, + size: usize, _marker: PhantomData, } @@ -31,10 +31,6 @@ impl ZnxInfos for MatZnxDft { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { self.rows } @@ -43,18 +39,14 @@ impl ZnxInfos for MatZnxDft { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.rows * self.cols * self.limbs + fn size(&self) -> usize { + self.size } } impl MatZnxDft { - fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { - let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, limbs)); + fn new(module: &Module, rows: usize, cols: usize, size: usize) -> MatZnxDft { + let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, size)); let ptr: *mut u8 = data.as_mut_ptr(); MatZnxDft:: { data: data, @@ -62,7 +54,7 @@ impl MatZnxDft { n: module.n(), rows: rows, cols: cols, - 
limbs: limbs, + size: size, _marker: PhantomData, } } @@ -115,7 +107,7 @@ impl MatZnxDft { fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); - let nsize: usize = self.limbs(); + let nsize: usize = self.size(); if col == (nsize - 1) && (nsize & 1 == 1) { &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] } else { @@ -127,7 +119,7 @@ impl MatZnxDft { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. pub trait MatZnxDftOps { - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize; + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// @@ -135,7 +127,7 @@ pub trait MatZnxDftOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft; + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. 
/// @@ -351,12 +343,12 @@ pub trait MatZnxDftOps { } impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols, limbs) + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols, size) } - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize } } fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { @@ -367,7 +359,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.n() * b.poly_count()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -376,7 +368,7 @@ impl MatZnxDftOps for Module { b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), b.rows() as u64, - (b.limbs() * b.cols()) as u64, + (b.size() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } @@ -385,8 +377,8 @@ impl MatZnxDftOps for Module { fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { - assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); + assert_eq!(a.len(), b.size() * self.n() * b.cols()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -396,7 +388,7 @@ impl MatZnxDftOps for Module { a.as_ptr(), row_i as u64, b.rows() as u64, - (b.limbs() * b.cols()) as u64, + (b.size() * b.cols()) as u64, 
tmp_bytes.as_mut_ptr(), ); } @@ -406,7 +398,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); assert_eq!(a.cols(), b.cols()); } unsafe { @@ -416,7 +408,7 @@ impl MatZnxDftOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - (a.limbs() * a.cols()) as u64, + (a.size() * a.cols()) as u64, ); } } @@ -425,7 +417,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); } unsafe { vmp::vmp_prepare_row_dft( @@ -434,7 +426,7 @@ impl MatZnxDftOps for Module { a.ptr as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, ); } } @@ -443,7 +435,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); } unsafe { vmp::vmp_extract_row_dft( @@ -452,7 +444,7 @@ impl MatZnxDftOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.limbs() as u64, + a.size() as u64, ); } } @@ -470,7 +462,7 @@ impl MatZnxDftOps for Module { } fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -479,20 +471,20 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.as_ptr(), - a.limbs() as u64, + a.size() as u64, (a.n() * a.cols()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } } fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, 
a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -501,13 +493,13 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.as_ptr(), - a.limbs() as u64, - (a.n() * a.limbs()) as u64, + a.size() as u64, + (a.n() * a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -526,7 +518,7 @@ impl MatZnxDftOps for Module { } fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -535,12 +527,12 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.ptr as *const vec_znx_dft_t, - a.limbs() as u64, + a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -553,7 +545,7 @@ impl MatZnxDftOps for Module { b: &MatZnxDft, tmp_bytes: &mut [u8], ) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -562,19 +554,19 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft_add( self.ptr, c.ptr 
as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.ptr as *const vec_znx_dft_t, - a.limbs() as u64, + a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } } fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -583,12 +575,12 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, - a.limbs() as u64, + a.size() as u64, tmp_bytes.as_mut_ptr(), ) } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 80d174c..a96937e 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -3,8 +3,8 @@ use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { - /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source); + /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. 
fn add_dist_f64>( @@ -32,11 +32,11 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - (0..limbs).for_each(|j| { + (0..size).for_each(|j| { a.at_poly_mut(col_i, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); @@ -114,17 +114,17 @@ mod tests { let n: usize = 4096; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; + let size: usize = 5; let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, limbs); - module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source); + let mut a: VecZnx = VecZnx::new(&module, cols, size); + module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { - (0..limbs).for_each(|limb_i| { + (0..size).for_each(|limb_i| { assert_eq!(a.at_poly(col_j, limb_i), zero); }) } else { @@ -146,7 +146,7 @@ mod tests { let module: Module = Module::::new(n); let log_base2k: usize = 17; let log_k: usize = 2 * 17; - let limbs: usize = 5; + let size: usize = 5; let sigma: f64 = 3.2; let bound: f64 = 6.0 * sigma; let mut source: Source = Source::new([0u8; 32]); @@ -154,11 +154,11 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + let mut a: VecZnx = VecZnx::new(&module, cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { 
- (0..limbs).for_each(|limb_i| { + (0..size).for_each(|limb_i| { assert_eq!(a.at_poly(col_j, limb_i), zero); }) } else { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 7457ca2..cfe2f45 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -120,7 +120,7 @@ impl Scalar { VecZnx { n: self.n, cols: 1, - limbs: 1, + size: 1, data: Vec::new(), ptr: self.ptr, } diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 7fcf7c3..4e2a512 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -10,7 +10,7 @@ pub trait Stats { impl Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { - let prec: u32 = (self.limbs() * log_base2k) as u32; + let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); self.decode_vec_float(col_i, log_base2k, &mut data); // std = sqrt(sum((xi - avg)^2) / n) diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 89173f0..1bb8ab3 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,11 +1,10 @@ use crate::Backend; use crate::ZnxBase; use crate::cast_mut; -use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Module, ZnxInfos, ZnxLayout}; +use crate::switch_degree; +use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout}; use crate::{alloc_aligned, assert_alignement}; -use itertools::izip; use std::cmp::min; /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of @@ -26,8 +25,8 @@ pub struct VecZnx { /// The number of polynomials pub cols: usize, - /// The number of limbs per polynomial (a.k.a small polynomials). - pub limbs: usize, + /// The number of size per polynomial (a.k.a small polynomials). + pub size: usize, /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. 
pub data: Vec, @@ -41,10 +40,6 @@ impl ZnxInfos for VecZnx { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { 1 } @@ -53,12 +48,8 @@ impl ZnxInfos for VecZnx { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -74,6 +65,8 @@ impl ZnxLayout for VecZnx { } } +impl ZnxBasics for VecZnx {} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -89,28 +82,28 @@ impl ZnxBase for VecZnx { type Scalar = i64; /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, cols: cols, - limbs: limbs, + size: size, data: data, ptr: ptr, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - module.n() * cols * limbs * size_of::() + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + module.n() * cols * size * size_of::() } /// Returns a new struct implementing [VecZnx] with the provided data as backing array. @@ -118,14 +111,14 @@ impl ZnxBase for VecZnx { /// The struct will take ownership of buf[..[Self::bytes_of]] /// /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. 
- fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + /// the size of data is equal to [Self::bytes_of]. + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } unsafe { @@ -134,25 +127,25 @@ impl ZnxBase for VecZnx { Self { n: n, cols: cols, - limbs: limbs, + size: size, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert!(bytes.len() >= Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { n: module.n(), cols: cols, - limbs: limbs, + size: size, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } @@ -173,16 +166,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k)); + .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); } - self.limbs -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_limb_mut(self.limbs() - 1) + self.at_limb_mut(self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -196,52 +189,22 @@ impl VecZnx { self.data.len() == 0 } - pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } - } - pub fn normalize(&mut self, log_base2k: usize, carry: 
&mut [u8]) { normalize(log_base2k, self, carry) } - pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - pub fn switch_degree(&self, a: &mut Self) { switch_degree(a, self) } // Prints the first `n` coefficients of each limb pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) } } -pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { - let (n_in, n_out) = (a.n(), b.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - b.zero(); - } - - let limbs: usize = min(a.limbs(), b.limbs()); - - (0..limbs).for_each(|i| { - izip!( - a.at_limb(i).iter().step_by(gap_in), - b.at_limb_mut(i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); -} - -fn normalize_tmp_bytes(n: usize, limbs: usize) -> usize { - n * limbs * std::mem::size_of::() +fn normalize_tmp_bytes(n: usize, size: usize) -> usize { + n * size * std::mem::size_of::() } fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { @@ -264,7 +227,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.limbs()).rev().for_each(|i| { + (0..a.size()).rev().for_each(|i| { znx::znx_normalize( (n * cols) as u64, log_base2k as u64, @@ -276,462 +239,3 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { }); } } - -pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize { - n * limbs * std::mem::size_of::() -} - -pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { - let n: usize = a.n(); - let limbs: usize = a.limbs(); - - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= rsh_tmp_bytes(n, limbs), - "invalid carry: carry.len()/8={} < 
rsh_tmp_bytes({}, {})", - tmp_bytes.len() >> 3, - n, - limbs, - ); - assert_alignement(tmp_bytes.as_ptr()); - } - - let limbs: usize = a.limbs(); - let size_steps: usize = k / log_base2k; - - a.raw_mut().rotate_right(n * limbs * size_steps); - unsafe { - znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr()); - } - - let k_rem = k % log_base2k; - - if k_rem != 0 { - let carry_i64: &mut [i64] = cast_mut(tmp_bytes); - - unsafe { - znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr()); - } - - let log_base2k: usize = log_base2k; - - (size_steps..limbs).for_each(|i| { - izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << log_base2k; - *ci = get_base_k_carry(*xi, k_rem); - *xi = (*xi - *ci) >> k_rem; - }); - }) - } -} - -#[inline(always)] -fn get_base_k_carry(x: i64, k: usize) -> i64 { - (x << 64 - k) >> (64 - k) -} - -pub trait VecZnxOps { - /// Allocates a new [VecZnx]. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). - fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; - - fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx; - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnx] through [VecZnx::from_bytes]. - fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; - - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; - - /// c <- a + b. - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- b + a. - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// c <- a - b. - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- a - b. - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- b - a. 
- fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -a. - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -b. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx); - - /// b <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); - - /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); - - /// Splits b into subrings and copies them them into a. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of b have the same ring degree - /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); - - /// Merges the subrings a into b. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of a have the same ring degree - /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); -} - -impl VecZnxOps for Module { - fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { - VecZnx::new(self, cols, limbs) - } - - fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - VecZnx::bytes_of(self, cols, limbs) - } - - fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self, cols, limbs, bytes) - } - - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes) - } - - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } - } - - // c <- a + b - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = 
self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - c.as_mut_ptr(), - c.limbs() as u64, - (n * c.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- a + b - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // c <- a + b - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - c.as_mut_ptr(), - c.limbs() as u64, - (n * c.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- a - b - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- b - a - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - b.as_ptr(), - 
b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input. - /// * `b`: output. - /// * `k`: the power to which to map each coefficients. - /// * `a_size`: the number of a_size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `a` is greater than `a.limbs()`. 
- fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ); - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input and output. - /// * `k`: the power to which to map each coefficients. - /// * `a_size`: the number of size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `size` is greater than `self.limbs()`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ); - } - } - - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { - let (n_in, n_out) = (a.n(), b[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - b[1..].iter().for_each(|bi| { - debug_assert_eq!( - bi.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - b.iter_mut().enumerate().for_each(|(i, bi)| { - if i == 0 { - switch_degree(bi, a); - self.vec_znx_rotate(-1, buf, a); - } else { - switch_degree(bi, buf); - self.vec_znx_rotate_inplace(-1, buf); - } - }) - } - - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { - let (n_in, n_out) = (b.n(), a[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - a[1..].iter().for_each(|ai| { - debug_assert_eq!( - ai.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); 
- - a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(b, ai); - self.vec_znx_rotate_inplace(-1, b); - }); - - self.vec_znx_rotate_inplace(a.len() as i64, b); - } -} diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8c67a8d..7f647da 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -7,43 +7,43 @@ pub struct VecZnxBig { pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub limbs: usize, + pub size: usize, pub _marker: PhantomData, } impl ZnxBase for VecZnxBig { type Scalar = u8; - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols } + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } } /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. 
- fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -52,18 +52,18 @@ impl ZnxBase for VecZnxBig { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { @@ -71,17 +71,13 @@ impl ZnxBase for VecZnxBig { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } impl ZnxInfos for VecZnxBig { - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - fn n(&self) -> usize { self.n } @@ -94,12 +90,8 @@ impl ZnxInfos for VecZnxBig { 1 } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -117,13 +109,13 @@ impl ZnxLayout for VecZnxBig { impl VecZnxBig { pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. 
- fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -132,12 +124,12 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. + /// * `size`: the number of size (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -146,25 +138,25 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. + /// * `size`: the number of size (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. 
- fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize; + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] /// /// # Behavior /// - /// [VecZnxBig] (3 cols and 4 limbs) + /// [VecZnxBig] (3 cols and 4 size) /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// - - /// [VecZnx] (2 cols and 3 limbs) + /// [VecZnx] (2 cols and 3 size) /// [d0, e0] [d1, e1] [d2, e2] /// = /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] @@ -203,26 +195,26 @@ pub trait VecZnxBigOps { } impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { - VecZnxBig::new(self, cols, limbs) + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { + VecZnxBig::new(self, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, limbs, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, cols, size, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes) + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { - VecZnxBig::bytes_of(self, cols, limbs) + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + VecZnxBig::bytes_of(self, cols, size) } - /// [VecZnxBig] (3 cols and 4 limbs) + /// [VecZnxBig] (3 cols and 4 size) /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// - - /// [VecZnx] (2 cols and 3 limbs) + /// [VecZnx] (2 cols and 3 size) /// [d0, e0] [d1, e1] [d2, e2] /// = /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, 
b3, c3] @@ -306,10 +298,10 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, b.as_mut_ptr(), - b.limbs() as u64, + b.size() as u64, b.n() as u64, a.ptr as *mut vec_znx_big_t, - a.limbs() as u64, + a.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -344,7 +336,7 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, res.as_mut_ptr(), - res.limbs() as u64, + res.size() as u64, res.n() as u64, a.ptr as *mut vec_znx_big_t, a_range_begin as u64, diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 7724710..d9c9e60 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -10,44 +10,44 @@ pub struct VecZnxDft { pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub limbs: usize, + pub size: usize, pub _marker: PhantomData, } impl ZnxBase for VecZnxDft { type Scalar = u8; - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, n: module.n(), - limbs: limbs, + size: size, cols: cols, _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. 
- fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()) } unsafe { @@ -56,18 +56,18 @@ impl ZnxBase for VecZnxDft { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { @@ -75,7 +75,7 @@ impl ZnxBase for VecZnxDft { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } @@ -91,7 +91,7 @@ impl VecZnxDft { ptr: self.ptr, n: self.n, cols: self.cols, - limbs: self.limbs, + size: self.size, _marker: PhantomData, } } @@ -102,10 +102,6 @@ impl ZnxInfos for VecZnxDft { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { 1 } @@ -114,12 +110,8 @@ impl ZnxInfos for VecZnxDft { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -137,13 +129,13 @@ impl ZnxLayout for VecZnxDft { impl VecZnxDft { pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + (0..self.size()).for_each(|i| 
println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -156,7 +148,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -169,7 +161,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -180,7 +172,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize; + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. 
@@ -201,20 +193,20 @@ pub trait VecZnxDftOps { } impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft { - VecZnxDft::::new(&self, cols, limbs) + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { + VecZnxDft::::new(&self, cols, size) } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes(self, cols, size, tmp_bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes) + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, limbs) + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + VecZnxDft::bytes_of(&self, cols, size) } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { @@ -242,9 +234,9 @@ impl VecZnxDftOps for Module { vec_znx_dft::vec_znx_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, a.as_ptr(), - a.limbs() as u64, + a.size() as u64, (a.n() * a.cols()) as u64, ) } @@ -329,14 +321,14 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); - let limbs: usize = 2; + let size: usize = 2; let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, limbs); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); + let mut a: VecZnx = module.new_vec_znx(1, size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, 
size); let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, size, &mut source); let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs new file mode 100644 index 0000000..7afcc9a --- /dev/null +++ b/base2k/src/vec_znx_ops.rs @@ -0,0 +1,795 @@ +use crate::ffi::module::MODULE; +use crate::ffi::vec_znx; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op}; +use std::cmp::min; +pub trait VecZnxOps { + /// Allocates a new [VecZnx]. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials. + /// * `size`: the number of size per polynomial (a.k.a small polynomials). + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; + + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnx] through [VecZnx::from_bytes]. + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; + + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + + /// c <- a + b. + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- b + a. + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// c <- a - b. + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- a - b. + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- b - a. + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -a. + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -b. 
+ fn vec_znx_negate_inplace(&self, a: &mut VecZnx); + + /// b <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); + + /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); + + /// Splits b into subrings and copies them them into a. + /// + /// # Panics + /// + /// This method requires that all [VecZnx] of b have the same ring degree + /// and that b.n() * b.len() <= a.n() + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); + + /// Merges the subrings a into b. + /// + /// # Panics + /// + /// This method requires that all [VecZnx] of a have the same ring degree + /// and that a.n() * a.len() <= b.n() + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); +} + +impl VecZnxOps for Module { + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { + VecZnx::new(self, cols, size) + } + + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { + VecZnx::bytes_of(self, cols, size) + } + + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes(self, cols, size, bytes) + } + + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) + } + + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } + } + + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_add, + ); + vec_znx_apply_binary_op::(self, c, a, b, op); + } + + fn 
vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_sub, + ); + vec_znx_apply_binary_op::(self, c, a, b, op); + } + + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a); + } + } + + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_0( + self.ptr, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_negate, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr); + } + } + + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_rotate, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr); + } + } + + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. + /// + /// # Arguments + /// + /// * `a`: input. + /// * `b`: output. + /// * `k`: the power to which to map each coefficients. + /// * `a_size`: the number of a_size on which to apply the mapping. 
+ /// + /// # Panics + /// + /// The method will panic if the argument `a` is greater than `a.size()`. + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_automorphism, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. + /// + /// # Arguments + /// + /// * `a`: input and output. + /// * `k`: the power to which to map each coefficients. + /// * `a_size`: the number of size on which to apply the mapping. + /// + /// # Panics + /// + /// The method will panic if the argument `size` is greater than `self.size()`. + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr); + } + } + + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { + let (n_in, n_out) = (a.n(), b[0].n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + b[1..].iter().for_each(|bi| { + debug_assert_eq!( + bi.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + b.iter_mut().enumerate().for_each(|(i, bi)| { + if i == 0 { + switch_degree(bi, a); + self.vec_znx_rotate(-1, buf, a); + } else { + switch_degree(bi, buf); + self.vec_znx_rotate_inplace(-1, buf); + } + }) + } + + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { + let (n_in, n_out) = (b.n(), a[0].n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + a[1..].iter().for_each(|ai| { + debug_assert_eq!( + ai.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + a.iter().enumerate().for_each(|(_, ai)| { + switch_degree(b, ai); + self.vec_znx_rotate_inplace(-1, b); + }); + + self.vec_znx_rotate_inplace(a.len() as i64, b); + } +} 
+ +fn ffi_ternary_op_factory( + module_ptr: *const MODULE, + c_size: usize, + c_sl: usize, + a_size: usize, + a_sl: usize, + b_size: usize, + b_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64], &[i64]) { + move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe { + op_fn( + module_ptr, + cv.as_mut_ptr(), + c_size as u64, + c_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + bv.as_ptr(), + b_size as u64, + b_sl as u64, + ) + } +} + +fn ffi_binary_op_factory_type_0( + module_ptr: *const MODULE, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64]) { + move |bv: &mut [i64], av: &[i64]| unsafe { + op_fn( + module_ptr, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} + +fn ffi_binary_op_factory_type_1( + module_ptr: *const MODULE, + k: i64, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64]) { + move |bv: &mut [i64], av: &[i64]| unsafe { + op_fn( + module_ptr, + k, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} + +#[inline(always)] +pub fn vec_znx_apply_binary_op( + module: &Module, + c: &mut VecZnx, + a: &VecZnx, + b: &VecZnx, + op: impl Fn(&mut [i64], &[i64], &[i64]), +) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(c.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + + let min_ab_cols: usize = min(a_cols, b_cols); + let min_cols: usize = min(c_cols, min_ab_cols); + + // Applies over shared cols 
between (a, b, c) + (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); + // Copies/Negates/Zeroes the remaining cols if op is not inplace. + if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { + znx_post_process_ternary_op::(c, a, b); + } +} + +#[inline(always)] +pub fn vec_znx_apply_unary_op(module: &Module, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + // Applies over the shared cols between (a, b) + (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); + // Zeroes the remaining cols of b. + (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); +} + +#[cfg(test)] +mod tests { + use crate::{ + Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx, + znx_post_process_ternary_op, + }; + use itertools::izip; + use sampling::source::Source; + use std::cmp::min; + + #[test] + fn vec_znx_add() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { + izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai); + }; + test_binary_op::( + &module, + &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b), + op, + ); + } + + #[test] + fn vec_znx_add_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_sub() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { + izip!(cv.iter_mut(), 
bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai); + }; + test_binary_op::( + &module, + &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b), + op, + ); + } + + #[test] + fn vec_znx_sub_ab_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_sub_ba_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_negate() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |b: &mut [i64], a: &[i64]| { + izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai); + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a), + op, + ) + } + + #[test] + fn vec_znx_negate_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi); + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_negate_inplace(a), + op, + ) + } + + #[test] + fn vec_znx_rotate() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = 53; + let op = |b: &mut [i64], a: &[i64]| { + assert_eq!(b.len(), a.len()); + b.copy_from_slice(a); + + let mut k_mod2n: i64 = k % (2 * n as i64); + if k_mod2n < 0 { + k_mod2n += 2 * n as i64; + } + let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; + let k_modn: i64 = k_mod2n % (n as i64); + + b.rotate_right(k_modn as usize); + b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); + + if sign == 1 { + 
b.iter_mut().for_each(|x| *x = -*x); + } + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a), + op, + ) + } + + #[test] + fn vec_znx_rotate_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = 53; + let rot = |a: &mut [i64]| { + let mut k_mod2n: i64 = k % (2 * n as i64); + if k_mod2n < 0 { + k_mod2n += 2 * n as i64; + } + let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; + let k_modn: i64 = k_mod2n % (n as i64); + + a.rotate_right(k_modn as usize); + a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); + + if sign == 1 { + a.iter_mut().for_each(|x| *x = -*x); + } + }; + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a), + rot, + ) + } + + #[test] + fn vec_znx_automorphism() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = -5; + let op = |b: &mut [i64], a: &[i64]| { + assert_eq!(b.len(), a.len()); + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr, + k, + b.as_mut_ptr(), + 1u64, + n as u64, + a.as_ptr(), + 1u64, + n as u64, + ); + } + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a), + op, + ) + } + + #[test] + fn vec_znx_automorphism_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = -5; + let op = |a: &mut [i64]| unsafe { + vec_znx::vec_znx_automorphism( + module.ptr, + k, + a.as_mut_ptr(), + 1u64, + n as u64, + a.as_ptr(), + 1u64, + n as u64, + ); + }; + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a), + op, + ) + } + + fn test_binary_op( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 4; + let c_size: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + + [1usize, 2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + 
[1usize, 2, 3].iter().for_each(|c_cols| { + let min_ab_cols: usize = min(*a_cols, *b_cols); + let min_cols: usize = min(*c_cols, min_ab_cols); + let min_size: usize = min(c_size, min(a_size, b_size)); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..*b_cols).for_each(|i| { + module.fill_uniform(3, &mut b, i, b_size, &mut source); + }); + + let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size); + (0..c_have.cols()).for_each(|i| { + module.fill_uniform(3, &mut c_have, i, c_size, &mut source); + }); + + func_have(&mut c_have, &a, &b); + + let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size); + + // Adds with the minimum matching columns + (0..min_cols).for_each(|i| { + // Adds with th eminimum matching size + (0..min_size).for_each(|j| { + func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j)); + }); + + if a_size > b_size { + // Copies remaining size of lh if lh.size() > rh.size() + (min_size..a_size).for_each(|j| { + izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai); + if NEGATE { + c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + } else { + // Copies the remaining size of rh if the are greater + (min_size..b_size).for_each(|j| { + izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi); + if NEGATE { + c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + } + }); + + znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b); + + assert_eq!(c_have.raw(), c_want.raw()); + }); + }); + }); + } + + fn test_binary_op_inplace( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 5; + let mut source = Source::new([0u8; 32]); + + [1usize, 
2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + let min_cols: usize = min(*b_cols, *a_cols); + let min_size: usize = min(b_size, a_size); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..*b_cols).for_each(|i| { + module.fill_uniform(3, &mut b_have, i, b_size, &mut source); + }); + + let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); + b_want.raw_mut().copy_from_slice(b_have.raw()); + + func_have(&mut b_have, &a); + + // Applies with the minimum matching columns + (0..min_cols).for_each(|i| { + // Adds with th eminimum matching size + (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); + if NEGATE { + (min_size..b_size).for_each(|j| { + b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + }); + } + }); + + assert_eq!(b_have.raw(), b_want.raw()); + }); + }); + } + + fn test_unary_op( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 5; + let mut source = Source::new([0u8; 32]); + + [1usize, 2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + let min_cols: usize = min(*b_cols, *a_cols); + let min_size: usize = min(b_size, a_size); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..a.cols()).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..b_have.cols()).for_each(|i| { + module.fill_uniform(3, &mut b_have, i, b_size, &mut source); + }); + + let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); + + func_have(&mut b_have, &a); + + // Applies on the minimum matching columns + (0..min_cols).for_each(|i| { + // Applies on the minimum matching size + 
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); + + // Zeroes the unmatching size + (min_size..b_size).for_each(|j| { + b_want.zero_at(i, j); + }) + }); + + // Zeroes the unmatching columns + (min_cols..*b_cols).for_each(|i| { + (0..b_size).for_each(|j| { + b_want.zero_at(i, j); + }) + }); + + assert_eq!(b_have.raw(), b_want.raw()); + }); + }); + } + + fn test_unary_op_inplace(module: &Module, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) { + let a_size: usize = 3; + let mut source = Source::new([0u8; 32]); + [1usize, 2, 3].iter().for_each(|a_cols| { + let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a_have, i, a_size, &mut source); + }); + + let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size); + a_have.raw_mut().copy_from_slice(a_want.raw()); + + func_have(&mut a_have); + + // Applies on the minimum matching columns + (0..*a_cols).for_each(|i| { + // Applies on the minimum matching size + (0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j))); + }); + + assert_eq!(a_have.raw(), a_want.raw()); + }); + } +} From 48cfc0027b9abb0d2b74b03e1a817355cf546d8a Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 12:46:25 +0200 Subject: [PATCH 11/87] Updated vec_znx_ops doc --- base2k/src/vec_znx_ops.rs | 93 ++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 7afcc9a..4c8409d 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -8,49 +8,72 @@ pub trait VecZnxOps { /// # Arguments /// /// * `cols`: the number of polynomials. - /// * `size`: the number of size per polynomial (a.k.a small polynomials). + /// * `size`: the number small polynomials per column. fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; + /// Instantiates a new [VecZnx] from a slice of bytes. 
+ /// The returned [VecZnx] takes ownership of the slice of bytes. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials. + /// * `size`: the number small polynomials per column. + /// + /// # Panic + /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; + + /// Instantiates a new [VecZnx] from a slice of bytes. + /// The returned [VecZnx] does take ownership of the slice of bytes. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials. + /// * `size`: the number small polynomials per column. + /// + /// # Panic + /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnx] through [VecZnx::from_bytes]. + /// Returns the number of bytes necessary to allocate + /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] + /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; + /// Returns the minimum number of bytes necessary for normalization. fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; - /// c <- a + b. + /// Adds `a` to `b` and write the result on `c`. fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - /// b <- b + a. + /// Adds `a` to `b` and write the result on `b`. fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); - /// c <- a - b. + /// Subtracts `b` to `a` and write the result on `c`. fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - /// b <- a - b. + /// Subtracts `a` to `b` and write the result on `b`. fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); - /// b <- b - a. + /// Subtracts `b` to `a` and write the result on `b`. 
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); - /// b <- -a. + // Negates `a` and stores the result on `b`. fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); - /// b <- -b. + /// Negages `a` and stores the result on `a`. fn vec_znx_negate_inplace(&self, a: &mut VecZnx); - /// b <- a * X^k (mod X^{n} + 1) + /// Multiplies `a` by X^k and stores the result on `b`. fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - /// a <- a * X^k (mod X^{n} + 1) + /// Multiplies `a` by X^k and stores the result on `a`. fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); - /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); /// Splits b into subrings and copies them them into a. @@ -179,18 +202,6 @@ impl VecZnxOps for Module { } } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input. - /// * `b`: output. - /// * `k`: the power to which to map each coefficients. - /// * `a_size`: the number of a_size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `a` is greater than `a.size()`. fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { let op = ffi_binary_op_factory_type_1( self.ptr, @@ -204,17 +215,6 @@ impl VecZnxOps for Module { vec_znx_apply_unary_op::(self, b, a, op); } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input and output. - /// * `k`: the power to which to map each coefficients. 
- /// * `a_size`: the number of size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `size` is greater than `self.size()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; @@ -357,14 +357,11 @@ pub fn vec_znx_apply_binary_op( assert_eq!(c.n(), module.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } - let a_cols: usize = a.cols(); let b_cols: usize = b.cols(); let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); // Copies/Negates/Zeroes the remaining cols if op is not inplace. @@ -620,25 +617,30 @@ mod tests { let min_cols: usize = min(*c_cols, min_ab_cols); let min_size: usize = min(c_size, min(a_size, b_size)); + // Allocats a and populates with random values. let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); (0..*a_cols).for_each(|i| { module.fill_uniform(3, &mut a, i, a_size, &mut source); }); + // Allocats b and populates with random values. let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size); (0..*b_cols).for_each(|i| { module.fill_uniform(3, &mut b, i, b_size, &mut source); }); + // Allocats c and populates with random values. let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size); (0..c_have.cols()).for_each(|i| { module.fill_uniform(3, &mut c_have, i, c_size, &mut source); }); + // Applies the function to test func_have(&mut c_have, &a, &b); let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size); + // Applies the reference function and expected behavior. 
// Adds with the minimum matching columns (0..min_cols).for_each(|i| { // Adds with th eminimum matching size @@ -687,11 +689,13 @@ mod tests { let min_cols: usize = min(*b_cols, *a_cols); let min_size: usize = min(b_size, a_size); + // Allocats a and populates with random values. let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); (0..*a_cols).for_each(|i| { module.fill_uniform(3, &mut a, i, a_size, &mut source); }); + // Allocats b and populates with random values. let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); (0..*b_cols).for_each(|i| { module.fill_uniform(3, &mut b_have, i, b_size, &mut source); @@ -700,8 +704,10 @@ mod tests { let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); b_want.raw_mut().copy_from_slice(b_have.raw()); + // Applies the function to test. func_have(&mut b_have, &a); + // Applies the reference function and expected behavior. // Applies with the minimum matching columns (0..min_cols).for_each(|i| { // Adds with th eminimum matching size @@ -732,11 +738,13 @@ mod tests { let min_cols: usize = min(*b_cols, *a_cols); let min_size: usize = min(b_size, a_size); + // Allocats a and populates with random values. let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); (0..a.cols()).for_each(|i| { module.fill_uniform(3, &mut a, i, a_size, &mut source); }); + // Allocats b and populates with random values. let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); (0..b_have.cols()).for_each(|i| { module.fill_uniform(3, &mut b_have, i, b_size, &mut source); @@ -744,8 +752,10 @@ mod tests { let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); + // Applies the function to test. func_have(&mut b_have, &a); + // Applies the reference function and expected behavior. 
// Applies on the minimum matching columns (0..min_cols).for_each(|i| { // Applies on the minimum matching size @@ -778,11 +788,14 @@ mod tests { module.fill_uniform(3, &mut a_have, i, a_size, &mut source); }); + // Allocats a and populates with random values. let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size); a_have.raw_mut().copy_from_slice(a_want.raw()); + // Applies the function to test. func_have(&mut a_have); + // Applies the reference function and expected behavior. // Applies on the minimum matching columns (0..*a_cols).for_each(|i| { // Applies on the minimum matching size From d86d6b6ee86c813a1fc0107b317ec79371b4614e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 12:52:35 +0200 Subject: [PATCH 12/87] Updated vec_znx_big doc --- base2k/src/vec_znx_big.rs | 57 ++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 7f647da..d54d72d 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -19,7 +19,7 @@ impl ZnxBase for VecZnxBig { assert!(cols > 0); assert!(size > 0); } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); + let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, @@ -124,7 +124,7 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `size`: the number of size (a.k.a small polynomials) per polynomial. + /// * `size`: the number of polynomials per column. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics @@ -138,7 +138,7 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `size`: the number of size (a.k.a small polynomials) per polynomial. + /// * `size`: the number of polynomials per column. 
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics @@ -149,39 +149,45 @@ pub trait VecZnxBigOps { /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; - /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] - /// - /// # Behavior - /// - /// [VecZnxBig] (3 cols and 4 size) - /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] - /// - - /// [VecZnx] (2 cols and 3 size) - /// [d0, e0] [d1, e1] [d2, e2] - /// = - /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] + /// Subtracts `a` to `b` and stores the result on `b`. fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); - /// c <- b - a + /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - /// c <- b + a + /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - /// b <- b + a + /// Adds `a` to `b` and stores the result on `b`. fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; - /// b <- normalize(a) + /// Normalizes `a` and stores the result on `b`. + /// + /// # Arguments + /// + /// * `log_base2k`: normalization basis. + /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); + /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k]. fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; + /// Normalize `a`, taking into account column interleaving and stores the result on `b`. + /// + /// # Arguments + /// + /// * `log_base2k`: normalization basis. 
+ /// * `a_range_begin`: column to start. + /// * `a_range_end`: column to end. + /// * `a_range_step`: column step size. + /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes]. fn vec_znx_big_range_normalize_base2k( &self, log_base2k: usize, - res: &mut VecZnx, + b: &mut VecZnx, a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, @@ -189,9 +195,11 @@ pub trait VecZnxBigOps { tmp_bytes: &mut [u8], ); - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. + fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig); } impl VecZnxBigOps for Module { @@ -211,13 +219,6 @@ impl VecZnxBigOps for Module { VecZnxBig::bytes_of(self, cols, size) } - /// [VecZnxBig] (3 cols and 4 size) - /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] - /// - - /// [VecZnx] (2 cols and 3 size) - /// [d0, e0] [d1, e1] [d2, e2] - /// = - /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( From 3ee69866bd5936bb6c5ef390e5bbcac4bf45d562 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 14:33:07 +0200 Subject: [PATCH 13/87] Generalized apply_binary_op & apply_unary_op --- base2k/src/commons.rs | 55 +++++++++++++++++++++++++++++++++++++- base2k/src/vec_znx_ops.rs | 56 +++++---------------------------------- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 1d7a0c9..cfae556 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -244,7 
+244,7 @@ where }); } -pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) +pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) where ::Scalar: IntegerType, { @@ -292,3 +292,56 @@ where }); } } + +#[inline(always)] +pub fn apply_binary_op( + module: &Module, + c: &mut T, + a: &T, + b: &T, + op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]), +) where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(c.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + let min_ab_cols: usize = min(a_cols, b_cols); + let min_cols: usize = min(c_cols, min_ab_cols); + // Applies over shared cols between (a, b, c) + (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); + // Copies/Negates/Zeroes the remaining cols if op is not inplace. + if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { + znx_post_process_ternary_op::(c, a, b); + } +} + +#[inline(always)] +pub fn apply_unary_op( + module: &Module, + b: &mut T, + a: &T, + op: impl Fn(&mut [T::Scalar], &[T::Scalar]), +) where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + // Applies over the shared cols between (a, b) + (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); + // Zeroes the remaining cols of b. 
+ (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); +} diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 4c8409d..573e5b1 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,6 @@ use crate::ffi::module::MODULE; use crate::ffi::vec_znx; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op}; +use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout}; use std::cmp::min; pub trait VecZnxOps { /// Allocates a new [VecZnx]. @@ -125,7 +125,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_add, ); - vec_znx_apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -146,7 +146,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_sub, ); - vec_znx_apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -172,7 +172,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_negate, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { @@ -192,7 +192,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_rotate, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { @@ -212,7 +212,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_automorphism, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { @@ -342,50 +342,6 @@ fn ffi_binary_op_factory_type_1( } } -#[inline(always)] -pub fn vec_znx_apply_binary_op( - module: &Module, - c: &mut VecZnx, - a: &VecZnx, - b: &VecZnx, - op: impl 
Fn(&mut [i64], &[i64], &[i64]), -) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(c.n(), module.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); - let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) - (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); - // Copies/Negates/Zeroes the remaining cols if op is not inplace. - if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { - znx_post_process_ternary_op::(c, a, b); - } -} - -#[inline(always)] -pub fn vec_znx_apply_unary_op(module: &Module, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - // Applies over the shared cols between (a, b) - (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); - // Zeroes the remaining cols of b. 
- (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); -} - #[cfg(test)] mod tests { use crate::{ From bd933c0e94ef83875703e23987560a14a7d73d15 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 15:53:26 +0200 Subject: [PATCH 14/87] Added VecZnxBig ops --- base2k/examples/rlwe_encrypt.rs | 2 +- base2k/src/commons.rs | 104 +--------- base2k/src/internals.rs | 192 ++++++++++++++++++ base2k/src/lib.rs | 3 + base2k/src/vec_znx_big.rs | 269 +------------------------ base2k/src/vec_znx_big_ops.rs | 339 ++++++++++++++++++++++++++++++++ base2k/src/vec_znx_ops.rs | 61 +----- 7 files changed, 549 insertions(+), 421 deletions(-) create mode 100644 base2k/src/internals.rs create mode 100644 base2k/src/vec_znx_big_ops.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 3661f0d..8a5d09f 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -59,7 +59,7 @@ fn main() { m.normalize(log_base2k, &mut carry); // buf_big <- m - buf_big - module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); + module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m); println!("{:?}", buf_big.raw()); diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index cfae556..969897d 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -1,6 +1,6 @@ use crate::{Backend, Module, assert_alignement, cast_mut}; use itertools::izip; -use std::cmp::{max, min}; +use std::cmp::min; pub trait ZnxInfos { /// Returns the ring degree of the polynomials. 
@@ -243,105 +243,3 @@ where .for_each(|(x_in, x_out)| *x_out = *x_in); }); } - -pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) -where - ::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_ne!(a.as_ptr(), b.as_ptr()); - assert_ne!(b.as_ptr(), c.as_ptr()); - assert_ne!(a.as_ptr(), c.as_ptr()); - } - - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - - let min_ab_cols: usize = min(a_cols, b_cols); - let max_ab_cols: usize = max(a_cols, b_cols); - - // Copies shared shared cols between (c, max(a, b)) - if a_cols != b_cols { - let mut x: &T = a; - if a_cols < b_cols { - x = b; - } - - let min_size = min(c.size(), x.size()); - (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { - (0..min_size).for_each(|j| { - c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j)); - if NEGATE { - c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - (min_size..c.size()).for_each(|j| { - c.zero_at(i, j); - }); - }); - } - - // Zeroes the cols of c > max(a, b). - if c_cols > max_ab_cols { - (max_ab_cols..c_cols).for_each(|i| { - (0..c.size()).for_each(|j| { - c.zero_at(i, j); - }) - }); - } -} - -#[inline(always)] -pub fn apply_binary_op( - module: &Module, - c: &mut T, - a: &T, - b: &T, - op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]), -) where - ::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(c.n(), module.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); - let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) - (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); - // Copies/Negates/Zeroes the remaining cols if op is not inplace. 
- if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { - znx_post_process_ternary_op::(c, a, b); - } -} - -#[inline(always)] -pub fn apply_unary_op( - module: &Module, - b: &mut T, - a: &T, - op: impl Fn(&mut [T::Scalar], &[T::Scalar]), -) where - ::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - // Applies over the shared cols between (a, b) - (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); - // Zeroes the remaining cols of b. - (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); -} diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs new file mode 100644 index 0000000..d7b08dc --- /dev/null +++ b/base2k/src/internals.rs @@ -0,0 +1,192 @@ +use std::cmp::{max, min}; + +use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE}; + +pub(crate) fn znx_post_process_ternary_op(c: &mut C, a: &A, b: &B) +where + C: ZnxBasics + ZnxLayout, + A: ZnxBasics + ZnxLayout, + B: ZnxBasics + ZnxLayout, + C::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_ne!(a.as_ptr(), b.as_ptr()); + assert_ne!(b.as_ptr(), c.as_ptr()); + assert_ne!(a.as_ptr(), c.as_ptr()); + } + + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + + let min_ab_cols: usize = min(a_cols, b_cols); + let max_ab_cols: usize = max(a_cols, b_cols); + + // Copies shared shared cols between (c, max(a, b)) + if a_cols != b_cols { + if a_cols > b_cols { + let min_size = min(c.size(), a.size()); + (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { + (0..min_size).for_each(|j| { + c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j)); + if NEGATE { + c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + (min_size..c.size()).for_each(|j| { + c.zero_at(i, j); + }); + }); + } else { 
+ let min_size = min(c.size(), b.size()); + (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { + (0..min_size).for_each(|j| { + c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j)); + if NEGATE { + c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + (min_size..c.size()).for_each(|j| { + c.zero_at(i, j); + }); + }); + } + } + + // Zeroes the cols of c > max(a, b). + if c_cols > max_ab_cols { + (max_ab_cols..c_cols).for_each(|i| { + (0..c.size()).for_each(|j| { + c.zero_at(i, j); + }) + }); + } +} + +#[inline(always)] +pub fn apply_binary_op( + module: &Module, + c: &mut C, + a: &A, + b: &B, + op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]), +) where + BE: Backend, + C: ZnxBasics + ZnxLayout, + A: ZnxBasics + ZnxLayout, + B: ZnxBasics + ZnxLayout, + C::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(c.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + let min_ab_cols: usize = min(a_cols, b_cols); + let min_cols: usize = min(c_cols, min_ab_cols); + // Applies over shared cols between (a, b, c) + (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); + // Copies/Negates/Zeroes the remaining cols if op is not inplace. 
+ if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { + znx_post_process_ternary_op::(c, a, b); + } +} + +#[inline(always)] +pub fn apply_unary_op( + module: &Module, + b: &mut T, + a: &T, + op: impl Fn(&mut [T::Scalar], &[T::Scalar]), +) where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + // Applies over the shared cols between (a, b) + (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); + // Zeroes the remaining cols of b. + (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); +} + +pub fn ffi_ternary_op_factory( + module_ptr: *const MODULE, + c_size: usize, + c_sl: usize, + a_size: usize, + a_sl: usize, + b_size: usize, + b_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64), +) -> impl Fn(&mut [T], &[T], &[T]) { + move |cv: &mut [T], av: &[T], bv: &[T]| unsafe { + op_fn( + module_ptr, + cv.as_mut_ptr(), + c_size as u64, + c_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + bv.as_ptr(), + b_size as u64, + b_sl as u64, + ) + } +} + +pub fn ffi_binary_op_factory_type_0( + module_ptr: *const MODULE, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64), +) -> impl Fn(&mut [T], &[T]) { + move |bv: &mut [T], av: &[T]| unsafe { + op_fn( + module_ptr, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} + +pub fn ffi_binary_op_factory_type_1( + module_ptr: *const MODULE, + k: i64, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64), +) -> impl Fn(&mut [T], &[T]) { + move |bv: &mut [T], av: &[T]| unsafe { + op_fn( + 
module_ptr, + k, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 3c48319..2a9a899 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,6 +3,7 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; +mod internals; pub mod mat_znx_dft; pub mod module; pub mod sampling; @@ -10,6 +11,7 @@ pub mod scalar_znx_dft; pub mod stats; pub mod vec_znx; pub mod vec_znx_big; +pub mod vec_znx_big_ops; pub mod vec_znx_dft; pub mod vec_znx_ops; @@ -23,6 +25,7 @@ pub use scalar_znx_dft::*; pub use stats::*; pub use vec_znx::*; pub use vec_znx_big::*; +pub use vec_znx_big_ops::*; pub use vec_znx_dft::*; pub use vec_znx_ops::*; diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index d54d72d..67b75a2 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ -use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::ffi::vec_znx_big; +use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { @@ -10,6 +10,9 @@ pub struct VecZnxBig { pub size: usize, pub _marker: PhantomData, } + +impl ZnxBasics for VecZnxBig {} + impl ZnxBase for VecZnxBig { type Scalar = u8; @@ -112,265 +115,3 @@ impl VecZnxBig { (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } - -pub trait VecZnxBigOps { - /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; - - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. 
- /// - /// # Arguments - /// - /// * `cols`: the number of polynomials.. - /// * `size`: the number of polynomials per column. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; - - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials.. - /// * `size`: the number of polynomials per column. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; - - /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); - - /// Subtracts `b` to `a` and stores the result on `c`. - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - - /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - - /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); - - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; - - /// Normalizes `a` and stores the result on `b`. - /// - /// # Arguments - /// - /// * `log_base2k`: normalization basis. 
- /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); - - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k]. - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; - - /// Normalize `a`, taking into account column interleaving and stores the result on `b`. - /// - /// # Arguments - /// - /// * `log_base2k`: normalization basis. - /// * `a_range_begin`: column to start. - /// * `a_range_end`: column to end. - /// * `a_range_step`: column step size. - /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes]. - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - b: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ); - - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); - - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
- fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig); -} - -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { - VecZnxBig::new(self, cols, size) - } - - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, size, bytes) - } - - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) - } - - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, cols, size) - } - - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - vec_znx_big::vec_znx_big_sub_small_a( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.as_ptr(), - a.poly_count() as u64, - a.n() as u64, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - ) - } - } - - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_sub_small_a( - self.ptr, - c.ptr as *mut vec_znx_big_t, - c.poly_count() as u64, - a.as_ptr(), - a.poly_count() as u64, - a.n() as u64, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - ) - } - } - - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_add_small( - self.ptr, - c.ptr as *mut vec_znx_big_t, - c.poly_count() as u64, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.as_ptr(), - a.poly_count() as u64, - a.n() as u64, - ) - } - } - - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - vec_znx_big::vec_znx_big_add_small( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.as_ptr(), - a.poly_count() as u64, - a.n() as u64, - ) - } - } - - fn vec_znx_big_normalize_tmp_bytes(&self) 
-> usize { - unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_normalize_tmp_bytes(self) - ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_normalize_base2k( - self.ptr, - log_base2k as u64, - b.as_mut_ptr(), - b.size() as u64, - b.n() as u64, - a.ptr as *mut vec_znx_big_t, - a.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - res: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) - ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_range_normalize_base2k( - self.ptr, - log_base2k as u64, - res.as_mut_ptr(), - res.size() as u64, - res.n() as u64, - a.ptr as *mut vec_znx_big_t, - a_range_begin as u64, - a_range_xend as u64, - a_range_step as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - 
a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); - } - } - - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); - } - } -} diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs new file mode 100644 index 0000000..530cb54 --- /dev/null +++ b/base2k/src/vec_znx_big_ops.rs @@ -0,0 +1,339 @@ +use crate::ffi::vec_znx_big::vec_znx_big_t; +use crate::ffi::{vec_znx, vec_znx_big}; +use crate::internals::{apply_binary_op, ffi_ternary_op_factory}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; + +pub trait VecZnxBigOps { + /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; + + /// Returns a new [VecZnxBig] with the provided bytes array as backing array. + /// + /// Behavior: takes ownership of the backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials.. + /// * `size`: the number of polynomials per column. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; + + /// Returns a new [VecZnxBig] with the provided bytes array as backing array. + /// + /// Behavior: the backing array is only borrowed. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials.. + /// * `size`: the number of polynomials per column. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. 
+ fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; + + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig); + + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + + /// Subtracts `a` to `b` and stores the result on `c`. + fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig); + + /// Subtracts `a` to `b` and stores the result on `b`. + fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + + /// Subtracts `b` to `a` and stores the result on `b`. + fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + + /// Subtracts `b` to `a` and stores the result on `c`. + fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + + /// Subtracts `a` to `b` and stores the result on `b`. + fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + + /// Subtracts `b` to `a` and stores the result on `c`. + fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx); + + /// Subtracts `b` to `a` and stores the result on `b`. + fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + + /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; + + /// Normalizes `a` and stores the result on `b`. 
+ /// + /// # Arguments + /// + /// * `log_base2k`: normalization basis. + /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. + fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); + + /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k]. + fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; + + /// Normalize `a`, taking into account column interleaving and stores the result on `b`. + /// + /// # Arguments + /// + /// * `log_base2k`: normalization basis. + /// * `a_range_begin`: column to start. + /// * `a_range_end`: column to end. + /// * `a_range_step`: column step size. + /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes]. + fn vec_znx_big_range_normalize_base2k( + &self, + log_base2k: usize, + b: &mut VecZnx, + a: &VecZnxBig, + a_range_begin: usize, + a_range_xend: usize, + a_range_step: usize, + tmp_bytes: &mut [u8], + ); + + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. + fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); + + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
+ fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig); +} + +impl VecZnxBigOps for Module { + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { + VecZnxBig::new(self, cols, size) + } + + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, cols, size, bytes) + } + + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) + } + + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + VecZnxBig::bytes_of(self, cols, size) + } + + fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_add, + ); + apply_binary_op::, VecZnxBig, VecZnxBig, false>(self, c, a, b, op); + } + + fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_sub, + ); + apply_binary_op::, VecZnxBig, VecZnxBig, true>(self, c, a, b, op); + } + + fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a); + } + } + + fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx) { + let op = ffi_ternary_op_factory( + self.ptr, + 
c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_sub, + ); + apply_binary_op::, VecZnxBig, VecZnx, true>(self, c, a, b, op); + } + + fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a); + } + } + + fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_sub, + ); + apply_binary_op::, VecZnx, VecZnxBig, true>(self, c, a, b, op); + } + + fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_add, + ); + apply_binary_op::, VecZnx, VecZnxBig, false>(self, c, a, b, op); + } + + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; + Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } + } + + fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { + debug_assert!( + tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_big_normalize_tmp_bytes(self) + ); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()) + } + unsafe { + 
vec_znx_big::vec_znx_big_normalize_base2k( + self.ptr, + log_base2k as u64, + b.as_mut_ptr(), + b.size() as u64, + b.n() as u64, + a.ptr as *mut vec_znx_big_t, + a.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } + + fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { + unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } + } + + fn vec_znx_big_range_normalize_base2k( + &self, + log_base2k: usize, + res: &mut VecZnx, + a: &VecZnxBig, + a_range_begin: usize, + a_range_xend: usize, + a_range_step: usize, + tmp_bytes: &mut [u8], + ) { + debug_assert!( + tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) + ); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()) + } + unsafe { + vec_znx_big::vec_znx_big_range_normalize_base2k( + self.ptr, + log_base2k as u64, + res.as_mut_ptr(), + res.size() as u64, + res.n() as u64, + a.ptr as *mut vec_znx_big_t, + a_range_begin as u64, + a_range_xend as u64, + a_range_step as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + unsafe { + vec_znx_big::vec_znx_big_automorphism( + self.ptr, + gal_el, + b.ptr as *mut vec_znx_big_t, + b.poly_count() as u64, + a.ptr as *mut vec_znx_big_t, + a.poly_count() as u64, + ); + } + } + + fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { + unsafe { + vec_znx_big::vec_znx_big_automorphism( + self.ptr, + gal_el, + a.ptr as *mut vec_znx_big_t, + a.poly_count() as u64, + a.ptr as *mut vec_znx_big_t, + a.poly_count() as u64, + ); + } + } +} diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 573e5b1..f67b6b0 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,7 +1,7 @@ use 
crate::ffi::module::MODULE; use crate::ffi::vec_znx; -use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout}; -use std::cmp::min; +use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1}; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -125,7 +125,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_add, ); - apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -146,7 +146,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_sub, ); - apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -298,56 +298,11 @@ fn ffi_ternary_op_factory( } } -fn ffi_binary_op_factory_type_0( - module_ptr: *const MODULE, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64), -) -> impl Fn(&mut [i64], &[i64]) { - move |bv: &mut [i64], av: &[i64]| unsafe { - op_fn( - module_ptr, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - ) - } -} - -fn ffi_binary_op_factory_type_1( - module_ptr: *const MODULE, - k: i64, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64), -) -> impl Fn(&mut [i64], &[i64]) { - move |bv: &mut [i64], av: &[i64]| unsafe { - op_fn( - module_ptr, - k, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - ) - } -} - #[cfg(test)] mod tests { - use crate::{ - Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx, 
- znx_post_process_ternary_op, - }; + use crate::internals::znx_post_process_ternary_op; + use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx}; + use itertools::izip; use sampling::source::Source; use std::cmp::min; @@ -623,7 +578,7 @@ mod tests { } }); - znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b); + znx_post_process_ternary_op::(&mut c_want, &a, &b); assert_eq!(c_have.raw(), c_want.raw()); }); From 4f54234bc48f5fd37acac439e02c67bb5d5f21b0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 16:15:31 +0200 Subject: [PATCH 15/87] Finished adding VecZnxBig ops --- base2k/examples/rlwe_encrypt.rs | 20 ++--- base2k/src/vec_znx_big_ops.rs | 148 ++++++++++---------------------- base2k/src/vec_znx_ops.rs | 48 ++++++++++- 3 files changed, 100 insertions(+), 116 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 8a5d09f..803f371 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -8,17 +8,17 @@ use sampling::source::Source; fn main() { let n: usize = 16; let log_base2k: usize = 18; - let limbs: usize = 3; - let msg_cols: usize = 2; - let log_scale: usize = msg_cols * log_base2k - 5; + let ct_size: usize = 3; + let msg_size: usize = 2; + let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); + let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(1)); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, limbs); + let mut res: VecZnx = module.new_vec_znx(1, ct_size); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); @@ -31,8 +31,8 @@ fn main() { module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, limbs); - 
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + let mut a: VecZnx = module.new_vec_znx(1, ct_size); + module.fill_uniform(log_base2k, &mut a, 0, ct_size, &mut source); // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); @@ -48,7 +48,7 @@ fn main() { println!("{:?}", buf_big.raw()); - let mut m: VecZnx = module.new_vec_znx(1, msg_cols); + let mut m: VecZnx = module.new_vec_znx(1, msg_size); let mut want: Vec = vec![0; n]; want.iter_mut() @@ -64,14 +64,14 @@ fn main() { println!("{:?}", buf_big.raw()); // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, limbs); + let mut b: VecZnx = module.new_vec_znx(1, ct_size); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); b.print(n); module.add_normal( log_base2k, &mut b, 0, - log_base2k * limbs, + log_base2k * ct_size, &mut source, 3.2, 19.0, diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 530cb54..4b4e54e 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,7 +1,8 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; -use crate::ffi::{vec_znx, vec_znx_big}; -use crate::internals::{apply_binary_op, ffi_ternary_op_factory}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; +use std::cmp::min; + +use crate::ffi::vec_znx; +use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_1, ffi_ternary_op_factory}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement}; pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -73,7 +74,7 @@ pub trait VecZnxBigOps { fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. 
- fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; + fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize; /// Normalizes `a` and stores the result on `b`. /// @@ -83,29 +84,6 @@ pub trait VecZnxBigOps { /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k]. - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; - - /// Normalize `a`, taking into account column interleaving and stores the result on `b`. - /// - /// # Arguments - /// - /// * `log_base2k`: normalization basis. - /// * `a_range_begin`: column to start. - /// * `a_range_end`: column to end. - /// * `a_range_step`: column step size. - /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes]. - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - b: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); @@ -242,98 +220,58 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } + fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize { + Self::vec_znx_normalize_tmp_bytes(self, cols) } fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_normalize_tmp_bytes(self) - ); #[cfg(debug_assertions)] { - assert_alignement(tmp_bytes.as_ptr()) + assert!(tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(&self, a.cols())); + assert_alignement(tmp_bytes.as_ptr()); } - unsafe { - vec_znx_big::vec_znx_big_normalize_base2k( + + let a_size: usize = a.size(); + let b_size: usize = b.sl(); + let a_sl: usize = a.size(); + let b_sl: usize = a.sl(); + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + (0..min_cols).for_each(|i| unsafe { + vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - b.as_mut_ptr(), - b.size() as u64, - b.n() as u64, - a.ptr as *mut vec_znx_big_t, - a.size() as u64, + b.at_mut_ptr(i, 0), + b_size as u64, + b_sl as u64, + a.at_ptr(i, 0), + a_size as u64, + a_sl as u64, tmp_bytes.as_mut_ptr(), - ) - } + ); + }); + + (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); } - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - res: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, 
- tmp_bytes: &mut [u8], - ) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) + fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_automorphism, ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_range_normalize_base2k( - self.ptr, - log_base2k as u64, - res.as_mut_ptr(), - res.size() as u64, - res.n() as u64, - a.ptr as *mut vec_znx_big_t, - a_range_begin as u64, - a_range_xend as u64, - a_range_step as u64, - tmp_bytes.as_mut_ptr(), - ); - } + apply_unary_op::>(self, b, a, op); } - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig) { unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); - } - } - - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); + let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; + Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, &*a_ptr); } } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index f67b6b0..c7c8d85 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,7 +1,9 @@ +use std::cmp::min; + use crate::ffi::module::MODULE; use crate::ffi::vec_znx; use crate::internals::{apply_binary_op, apply_unary_op, 
ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1}; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree}; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -43,6 +45,12 @@ pub trait VecZnxOps { /// Returns the minimum number of bytes necessary for normalization. fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + /// Normalizes `a` and stores the result into `b`. + fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]); + + /// Normalizes `a` and stores the result into `a`. + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]); + /// Adds `a` to `b` and write the result on `c`. fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); @@ -114,6 +122,44 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } } + fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self, a.cols())); + assert_alignement(tmp_bytes.as_ptr()); + } + + let a_size: usize = a.size(); + let b_size: usize = b.sl(); + let a_sl: usize = a.size(); + let b_sl: usize = a.sl(); + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + (0..min_cols).for_each(|i| unsafe { + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + b.at_mut_ptr(i, 0), + b_size as u64, + b_sl as u64, + a.at_ptr(i, 0), + a_size as u64, + a_sl as u64, + tmp_bytes.as_mut_ptr(), + ); + }); + + (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); + } + + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut 
VecZnx; + Self::vec_znx_normalize(self, log_base2k, &mut *a_ptr, &*a_ptr, tmp_bytes); + } + } + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { let op = ffi_ternary_op_factory( self.ptr, From 917a4724375166e3f504096fd40178cc33e2c849 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 18:14:16 +0200 Subject: [PATCH 16/87] wip: change of approach, enables to select columns on which to operate --- base2k/examples/rlwe_encrypt.rs | 65 +-- base2k/src/commons.rs | 10 +- base2k/src/internals.rs | 96 ---- base2k/src/scalar_znx_dft.rs | 12 +- base2k/src/vec_znx.rs | 4 +- base2k/src/vec_znx_big_ops.rs | 6 +- base2k/src/vec_znx_ops.rs | 795 ++++++++------------------------ 7 files changed, 250 insertions(+), 738 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 803f371..395fdf6 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxLayout, alloc_aligned, + VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, }; use itertools::izip; use sampling::source::Source; @@ -13,13 +13,11 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(1)); + let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2)); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, ct_size); - // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); s.fill_ternary_prob(0.5, &mut source); @@ -30,47 +28,50 @@ fn main() { // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); - // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, ct_size); - 
module.fill_uniform(log_base2k, &mut a, 0, ct_size, &mut source); + // ct = (c0, c1) + let mut ct: VecZnx = module.new_vec_znx(2, ct_size); + + // Fill c1 with random values + module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, ct.size()); - // Applies buf_dft <- s * a - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + // Applies buf_dft <- s * c1 + module.svp_apply_dft( + &mut buf_dft, // DFT(c1 * s) + &s_ppol, + &ct, + 1, // c1 + ); - // Alias scratch space + // Alias scratch space (VecZnxDftis always at least as big as VecZnxBig) let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); - // buf_big <- IDFT(buf_dft) (not normalized) + // BIG(c1 * s) <- IDFT(DFT(c1 * s)) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - println!("{:?}", buf_big.raw()); - + // m <- (0) let mut m: VecZnx = module.new_vec_znx(1, msg_size); - let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); - - // m m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); m.normalize(log_base2k, &mut carry); - // buf_big <- m - buf_big + // m - BIG(c1 * s) module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m); - println!("{:?}", buf_big.raw()); + // c0 <- m - BIG(c1 * s) + module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry); - // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); - b.print(n); + ct.print(ct.sl()); + + // (c0 + e, c1) module.add_normal( log_base2k, - &mut b, - 0, + &mut ct, + 0, // c0 log_base2k * ct_size, &mut source, 3.2, @@ -79,16 +80,16 @@ fn main() { // Decrypt - // buf_big <- a * s - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + // DFT(c1 * s) + module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1); + // 
BIG(c1 * s) = IDFT(DFT(c1 * s)) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - // buf_big <- a * s + b - module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + // BIG(c1 * s) + c0 + module.vec_znx_big_add_small_inplace(&mut buf_big, &ct); - println!("raw: {:?}", &buf_big.raw()); - - // res <- normalize(buf_big) + // m + e <- BIG(c1 * s + c0) + let mut res: VecZnx = module.new_vec_znx(1, ct_size); module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); // have = m * 2^{log_scale} + e diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 969897d..d5f60ee 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -81,12 +81,12 @@ pub trait ZnxLayout: ZnxInfos { } /// Returns non-mutable reference to the (i, j)-th small polynomial. - fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { + fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } } /// Returns mutable reference to the (i, j)-th small polynomial. 
- fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { + fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } @@ -219,7 +219,7 @@ pub fn rsh_tmp_bytes(n: usize, cols: usize) -> usize { n * cols * std::mem::size_of::() } -pub fn switch_degree(b: &mut T, a: &T) +pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) where ::Scalar: IntegerType, { @@ -237,8 +237,8 @@ where (0..size).for_each(|i| { izip!( - a.at_limb(i).iter().step_by(gap_in), - b.at_limb_mut(i).iter_mut().step_by(gap_out) + a.at(col_a, i).iter().step_by(gap_in), + b.at_mut(col_b, i).iter_mut().step_by(gap_out) ) .for_each(|(x_in, x_out)| *x_out = *x_in); }); diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs index d7b08dc..f2fbe3b 100644 --- a/base2k/src/internals.rs +++ b/base2k/src/internals.rs @@ -2,102 +2,6 @@ use std::cmp::{max, min}; use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE}; -pub(crate) fn znx_post_process_ternary_op(c: &mut C, a: &A, b: &B) -where - C: ZnxBasics + ZnxLayout, - A: ZnxBasics + ZnxLayout, - B: ZnxBasics + ZnxLayout, - C::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_ne!(a.as_ptr(), b.as_ptr()); - assert_ne!(b.as_ptr(), c.as_ptr()); - assert_ne!(a.as_ptr(), c.as_ptr()); - } - - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - - let min_ab_cols: usize = min(a_cols, b_cols); - let max_ab_cols: usize = max(a_cols, b_cols); - - // Copies shared shared cols between (c, max(a, b)) - if a_cols != b_cols { - if a_cols > b_cols { - let min_size = min(c.size(), a.size()); - (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { - (0..min_size).for_each(|j| { - c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j)); - if NEGATE { - c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - (min_size..c.size()).for_each(|j| { - c.zero_at(i, j); - }); 
- }); - } else { - let min_size = min(c.size(), b.size()); - (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { - (0..min_size).for_each(|j| { - c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j)); - if NEGATE { - c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - (min_size..c.size()).for_each(|j| { - c.zero_at(i, j); - }); - }); - } - } - - // Zeroes the cols of c > max(a, b). - if c_cols > max_ab_cols { - (max_ab_cols..c_cols).for_each(|i| { - (0..c.size()).for_each(|j| { - c.zero_at(i, j); - }) - }); - } -} - -#[inline(always)] -pub fn apply_binary_op( - module: &Module, - c: &mut C, - a: &A, - b: &B, - op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]), -) where - BE: Backend, - C: ZnxBasics + ZnxLayout, - A: ZnxBasics + ZnxLayout, - B: ZnxBasics + ZnxLayout, - C::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(c.n(), module.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); - let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) - (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); - // Copies/Negates/Zeroes the remaining cols if op is not inplace. - if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { - znx_post_process_ternary_op::(c, a, b); - } -} - #[inline(always)] pub fn apply_unary_op( module: &Module, diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index cfe2f45..474135b 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -230,7 +230,7 @@ pub trait ScalarZnxDftOps { /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. 
- fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); } impl ScalarZnxDftOps for Module { @@ -261,16 +261,16 @@ impl ScalarZnxDftOps for Module { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { unsafe { svp::svp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.size() as u64, a.ptr as *const svp_ppol_t, - b.as_ptr(), - b.cols() as u64, - b.n() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, ) } } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 1bb8ab3..53aeb39 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -193,8 +193,8 @@ impl VecZnx { normalize(log_base2k, self, carry) } - pub fn switch_degree(&self, a: &mut Self) { - switch_degree(a, self) + pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { + switch_degree(a, col_a, self, col) } // Prints the first `n` coefficients of each limb diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 4b4e54e..c87c95d 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -232,9 +232,9 @@ impl VecZnxBigOps for Module { } let a_size: usize = a.size(); - let b_size: usize = b.sl(); - let a_sl: usize = a.size(); - let b_sl: usize = a.sl(); + let b_size: usize = b.size(); + let a_sl: usize = a.sl(); + let b_sl: usize = b.sl(); let a_cols: usize = a.cols(); let b_cols: usize = b.cols(); let min_cols: usize = min(a_cols, b_cols); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index c7c8d85..9f2d43a 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,9 +1,5 @@ -use std::cmp::min; - -use crate::ffi::module::MODULE; use 
crate::ffi::vec_znx; -use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1}; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -43,62 +39,70 @@ pub trait VecZnxOps { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; /// Returns the minimum number of bytes necessary for normalization. - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + fn vec_znx_normalize_tmp_bytes(&self) -> usize; - /// Normalizes `a` and stores the result into `b`. - fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]); + /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. + fn vec_znx_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnx, + col_a: usize, + tmp_bytes: &mut [u8], + ); - /// Normalizes `a` and stores the result into `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]); + /// Normalizes the selected column of `a`. + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]); - /// Adds `a` to `b` and write the result on `c`. - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. + fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); - /// Adds `a` to `b` and write the result on `b`. - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. 
+ fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Subtracts `b` to `a` and write the result on `c`. - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. + fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); - /// Subtracts `a` to `b` and write the result on `b`. - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Subtracts the selected column of `a` to the selected column of `res`. + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Subtracts `b` to `a` and write the result on `b`. - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - // Negates `a` and stores the result on `b`. - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); + // Negates the selected column of `a` and stores the result on the selected column of `res`. + fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Negages `a` and stores the result on `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx); + /// Negates the selected column of `a`. + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize); - /// Multiplies `a` by X^k and stores the result on `b`. - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Multiplies `a` by X^k and stores the result on `a`. 
- fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); + /// Multiplies the selected column of `a` by X^k. + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); + /// Applies the automorphism X^i -> X^ik on the selected column of `a`. + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); - /// Splits b into subrings and copies them them into a. + /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// /// # Panics /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); + fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx); - /// Merges the subrings a into b. + /// Merges the subrings of the selected column of `a` into the selected column of `res`. 
/// /// # Panics /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); + fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize); } impl VecZnxOps for Module { @@ -118,164 +122,213 @@ impl VecZnxOps for Module { VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnx, + col_a: usize, + tmp_bytes: &mut [u8], + ) { #[cfg(debug_assertions)] { - assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self, a.cols())); + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); assert_alignement(tmp_bytes.as_ptr()); } - - let a_size: usize = a.size(); - let b_size: usize = b.sl(); - let a_sl: usize = a.size(); - let b_sl: usize = a.sl(); - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - (0..min_cols).for_each(|i| unsafe { + unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - b.at_mut_ptr(i, 0), - b_size as u64, - b_sl as u64, - a.at_ptr(i, 0), - a_size as u64, - a_sl as u64, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, tmp_bytes.as_mut_ptr(), ); - }); - - (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); + } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, 
tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_normalize(self, log_base2k, &mut *a_ptr, &*a_ptr, tmp_bytes); + Self::vec_znx_normalize( + self, + log_base2k, + &mut *a_ptr, + col_a, + &*a_ptr, + col_a, + tmp_bytes, + ); } } - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_add, - ); - apply_binary_op::(self, c, a, b, op); - } - - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr); + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) } } - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_sub, - ); - apply_binary_op::(self, c, a, b, op); - } - - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr); + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); } } - fn vec_znx_sub_ba_inplace(&self, 
b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) } } - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_0( - self.ptr, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_negate, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + } + } + + fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr); + Self::vec_znx_negate(self, &mut 
*a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_1( - self.ptr, - k, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_rotate, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_rotate( + self.ptr, + k, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr); + Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_1( - self.ptr, - k, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_automorphism, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr); + Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_split(&self, b: &mut 
Vec, a: &VecZnx, buf: &mut VecZnx) { - let (n_in, n_out) = (a.n(), b[0].n()); + fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) { + let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); - b[1..].iter().for_each(|bi| { + res[1..].iter().for_each(|bi| { debug_assert_eq!( bi.n(), n_out, @@ -283,19 +336,19 @@ impl VecZnxOps for Module { ) }); - b.iter_mut().enumerate().for_each(|(i, bi)| { + res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, a); - self.vec_znx_rotate(-1, buf, a); + switch_degree(bi, col_res, a, col_a); + self.vec_znx_rotate(-1, buf, 0, a, col_a); } else { - switch_degree(bi, buf); - self.vec_znx_rotate_inplace(-1, buf); + switch_degree(bi, col_res, buf, col_a); + self.vec_znx_rotate_inplace(-1, buf, col_a); } }) } - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { - let (n_in, n_out) = (b.n(), a[0].n()); + fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize) { + let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( n_out < n_in, @@ -310,456 +363,10 @@ impl VecZnxOps for Module { }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(b, ai); - self.vec_znx_rotate_inplace(-1, b); + switch_degree(res, col_res, ai, col_a); + self.vec_znx_rotate_inplace(-1, res, col_res); }); - self.vec_znx_rotate_inplace(a.len() as i64, b); - } -} - -fn ffi_ternary_op_factory( - module_ptr: *const MODULE, - c_size: usize, - c_sl: usize, - a_size: usize, - a_sl: usize, - b_size: usize, - b_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64), -) -> impl Fn(&mut [i64], &[i64], &[i64]) { - move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe { - op_fn( - module_ptr, - cv.as_mut_ptr(), - c_size as u64, - c_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - bv.as_ptr(), - b_size as u64, - b_sl as u64, - ) - } -} - 
-#[cfg(test)] -mod tests { - use crate::internals::znx_post_process_ternary_op; - use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx}; - - use itertools::izip; - use sampling::source::Source; - use std::cmp::min; - - #[test] - fn vec_znx_add() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { - izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai); - }; - test_binary_op::( - &module, - &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b), - op, - ); - } - - #[test] - fn vec_znx_add_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_sub() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { - izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai); - }; - test_binary_op::( - &module, - &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b), - op, - ); - } - - #[test] - fn vec_znx_sub_ab_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_sub_ba_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| 
module.vec_znx_sub_ba_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_negate() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |b: &mut [i64], a: &[i64]| { - izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai); - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a), - op, - ) - } - - #[test] - fn vec_znx_negate_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi); - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_negate_inplace(a), - op, - ) - } - - #[test] - fn vec_znx_rotate() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = 53; - let op = |b: &mut [i64], a: &[i64]| { - assert_eq!(b.len(), a.len()); - b.copy_from_slice(a); - - let mut k_mod2n: i64 = k % (2 * n as i64); - if k_mod2n < 0 { - k_mod2n += 2 * n as i64; - } - let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; - let k_modn: i64 = k_mod2n % (n as i64); - - b.rotate_right(k_modn as usize); - b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); - - if sign == 1 { - b.iter_mut().for_each(|x| *x = -*x); - } - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a), - op, - ) - } - - #[test] - fn vec_znx_rotate_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = 53; - let rot = |a: &mut [i64]| { - let mut k_mod2n: i64 = k % (2 * n as i64); - if k_mod2n < 0 { - k_mod2n += 2 * n as i64; - } - let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; - let k_modn: i64 = k_mod2n % (n as i64); - - a.rotate_right(k_modn as usize); - a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); - - if sign == 1 { - a.iter_mut().for_each(|x| *x = -*x); - } - }; - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a), - rot, - ) - } - - #[test] - fn vec_znx_automorphism() { - let n: usize = 8; - let module: 
Module = Module::::new(n); - let k: i64 = -5; - let op = |b: &mut [i64], a: &[i64]| { - assert_eq!(b.len(), a.len()); - unsafe { - vec_znx::vec_znx_automorphism( - module.ptr, - k, - b.as_mut_ptr(), - 1u64, - n as u64, - a.as_ptr(), - 1u64, - n as u64, - ); - } - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a), - op, - ) - } - - #[test] - fn vec_znx_automorphism_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = -5; - let op = |a: &mut [i64]| unsafe { - vec_znx::vec_znx_automorphism( - module.ptr, - k, - a.as_mut_ptr(), - 1u64, - n as u64, - a.as_ptr(), - 1u64, - n as u64, - ); - }; - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a), - op, - ) - } - - fn test_binary_op( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 4; - let c_size: usize = 5; - let mut source: Source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - [1usize, 2, 3].iter().for_each(|c_cols| { - let min_ab_cols: usize = min(*a_cols, *b_cols); - let min_cols: usize = min(*c_cols, min_ab_cols); - let min_size: usize = min(c_size, min(a_size, b_size)); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. - let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..*b_cols).for_each(|i| { - module.fill_uniform(3, &mut b, i, b_size, &mut source); - }); - - // Allocats c and populates with random values. 
- let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size); - (0..c_have.cols()).for_each(|i| { - module.fill_uniform(3, &mut c_have, i, c_size, &mut source); - }); - - // Applies the function to test - func_have(&mut c_have, &a, &b); - - let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size); - - // Applies the reference function and expected behavior. - // Adds with the minimum matching columns - (0..min_cols).for_each(|i| { - // Adds with th eminimum matching size - (0..min_size).for_each(|j| { - func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j)); - }); - - if a_size > b_size { - // Copies remaining size of lh if lh.size() > rh.size() - (min_size..a_size).for_each(|j| { - izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai); - if NEGATE { - c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - } else { - // Copies the remaining size of rh if the are greater - (min_size..b_size).for_each(|j| { - izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi); - if NEGATE { - c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - } - }); - - znx_post_process_ternary_op::(&mut c_want, &a, &b); - - assert_eq!(c_have.raw(), c_want.raw()); - }); - }); - }); - } - - fn test_binary_op_inplace( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 5; - let mut source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - let min_cols: usize = min(*b_cols, *a_cols); - let min_size: usize = min(b_size, a_size); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. 
- let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..*b_cols).for_each(|i| { - module.fill_uniform(3, &mut b_have, i, b_size, &mut source); - }); - - let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); - b_want.raw_mut().copy_from_slice(b_have.raw()); - - // Applies the function to test. - func_have(&mut b_have, &a); - - // Applies the reference function and expected behavior. - // Applies with the minimum matching columns - (0..min_cols).for_each(|i| { - // Adds with th eminimum matching size - (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); - if NEGATE { - (min_size..b_size).for_each(|j| { - b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - }); - } - }); - - assert_eq!(b_have.raw(), b_want.raw()); - }); - }); - } - - fn test_unary_op( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 5; - let mut source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - let min_cols: usize = min(*b_cols, *a_cols); - let min_size: usize = min(b_size, a_size); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..a.cols()).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. - let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..b_have.cols()).for_each(|i| { - module.fill_uniform(3, &mut b_have, i, b_size, &mut source); - }); - - let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); - - // Applies the function to test. - func_have(&mut b_have, &a); - - // Applies the reference function and expected behavior. 
- // Applies on the minimum matching columns - (0..min_cols).for_each(|i| { - // Applies on the minimum matching size - (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); - - // Zeroes the unmatching size - (min_size..b_size).for_each(|j| { - b_want.zero_at(i, j); - }) - }); - - // Zeroes the unmatching columns - (min_cols..*b_cols).for_each(|i| { - (0..b_size).for_each(|j| { - b_want.zero_at(i, j); - }) - }); - - assert_eq!(b_have.raw(), b_want.raw()); - }); - }); - } - - fn test_unary_op_inplace(module: &Module, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) { - let a_size: usize = 3; - let mut source = Source::new([0u8; 32]); - [1usize, 2, 3].iter().for_each(|a_cols| { - let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a_have, i, a_size, &mut source); - }); - - // Allocats a and populates with random values. - let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size); - a_have.raw_mut().copy_from_slice(a_want.raw()); - - // Applies the function to test. - func_have(&mut a_have); - - // Applies the reference function and expected behavior. 
- // Applies on the minimum matching columns - (0..*a_cols).for_each(|i| { - // Applies on the minimum matching size - (0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j))); - }); - - assert_eq!(a_have.raw(), a_want.raw()); - }); + self.vec_znx_rotate_inplace(a.len() as i64, res, col_res); } } From 06d0c5e8328cd7888fb66bc389e1903ca9bf99a6 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 18:16:09 +0200 Subject: [PATCH 17/87] more fixes --- base2k/src/encoding.rs | 24 +++++----- base2k/src/internals.rs | 96 --------------------------------------- base2k/src/lib.rs | 1 - base2k/src/sampling.rs | 18 ++++---- base2k/src/vec_znx_dft.rs | 2 +- 5 files changed, 22 insertions(+), 119 deletions(-) delete mode 100644 base2k/src/internals.rs diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 8c41381..7f8a0cc 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -107,7 +107,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // values on the last limb. // Else we decompose values base2k. 
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); @@ -116,7 +116,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } @@ -124,7 +124,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, if log_k_rem != log_base2k { let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); (size - steps..size).rev().for_each(|i| { - a.at_poly_mut(col_i, i)[..data_len] + a.at_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) @@ -143,16 +143,16 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat ); assert!(col_i < a.cols()); } - data.copy_from_slice(a.at_poly(col_i, 0)); + data.copy_from_slice(a.at(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); (1..size).for_each(|i| { if i == size - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } @@ -180,12 +180,12 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo // y[i] = sum x[j][i] * 2^{-log_base2k*j} (0..size).for_each(|i| { if i == 0 { 
- izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -209,13 +209,13 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); + (0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, size - 1)[i] = value; + a.at_mut(col_i, size - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); @@ -223,7 +223,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz .rev() .enumerate() .for_each(|(j, j_rev)| { - a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; + a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } @@ -231,7 +231,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz if log_k_rem != log_base2k { let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); (size - steps..size).rev().for_each(|j| { - a.at_poly_mut(col_i, j)[i] <<= log_k_rem; + a.at_mut(col_i, j)[i] <<= log_k_rem; }) } } diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs deleted file mode 100644 index f2fbe3b..0000000 --- a/base2k/src/internals.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::cmp::{max, min}; - -use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, 
ffi::module::MODULE}; - -#[inline(always)] -pub fn apply_unary_op( - module: &Module, - b: &mut T, - a: &T, - op: impl Fn(&mut [T::Scalar], &[T::Scalar]), -) where - ::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - // Applies over the shared cols between (a, b) - (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); - // Zeroes the remaining cols of b. - (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); -} - -pub fn ffi_ternary_op_factory( - module_ptr: *const MODULE, - c_size: usize, - c_sl: usize, - a_size: usize, - a_sl: usize, - b_size: usize, - b_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T], &[T]) { - move |cv: &mut [T], av: &[T], bv: &[T]| unsafe { - op_fn( - module_ptr, - cv.as_mut_ptr(), - c_size as u64, - c_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - bv.as_ptr(), - b_size as u64, - b_sl as u64, - ) - } -} - -pub fn ffi_binary_op_factory_type_0( - module_ptr: *const MODULE, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T]) { - move |bv: &mut [T], av: &[T]| unsafe { - op_fn( - module_ptr, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - ) - } -} - -pub fn ffi_binary_op_factory_type_1( - module_ptr: *const MODULE, - k: i64, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T]) { - move |bv: &mut [T], av: &[T]| unsafe { - op_fn( - module_ptr, - k, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - 
a_size as u64, - a_sl as u64, - ) - } -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 2a9a899..7a8a3f8 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,7 +3,6 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; -mod internals; pub mod mat_znx_dft; pub mod module; pub mod sampling; diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index a96937e..5261207 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -32,12 +32,12 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; (0..size).for_each(|j| { - a.at_poly_mut(col_i, j) + a.at_mut(col_a, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); }) @@ -47,7 +47,7 @@ impl Sampling for Module { &self, log_base2k: usize, a: &mut VecZnx, - col_i: usize, + col_a: usize, log_k: usize, source: &mut Source, dist: D, @@ -63,7 +63,7 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + a.at_mut(col_a, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -71,7 +71,7 @@ impl Sampling for Module { *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { - a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + a.at_mut(col_a, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -85,7 +85,7 @@ impl Sampling for Module { &self, 
log_base2k: usize, a: &mut VecZnx, - col_i: usize, + col_a: usize, log_k: usize, source: &mut Source, sigma: f64, @@ -94,7 +94,7 @@ impl Sampling for Module { self.add_dist_f64( log_base2k, a, - col_i, + col_a, log_k, source, Normal::new(0.0, sigma).unwrap(), @@ -125,7 +125,7 @@ mod tests { (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { - assert_eq!(a.at_poly(col_j, limb_i), zero); + assert_eq!(a.at(col_j, limb_i), zero); }) } else { let std: f64 = a.std(col_i, log_base2k); @@ -159,7 +159,7 @@ mod tests { (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { - assert_eq!(a.at_poly(col_j, limb_i), zero); + assert_eq!(a.at(col_j, limb_i), zero); }) } else { let std: f64 = a.std(col_i, log_base2k) * k_f64; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index d9c9e60..1b88af5 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -341,7 +341,7 @@ mod tests { module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); // a <- AUTO(a) - module.vec_znx_automorphism_inplace(p, &mut a); + module.vec_znx_automorphism_inplace(p, &mut a, 0); // b_dft <- DFT(AUTO(a)) module.vec_znx_dft(&mut b_dft, &a); From 2cc51eee18bcf2905b000a236fdd27906b389ae0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 21:53:27 +0200 Subject: [PATCH 18/87] working rlwe encryption example with interleaved polynomial --- base2k/examples/rlwe_encrypt.rs | 94 +++-- base2k/examples/vector_matrix_product.rs | 2 +- base2k/src/vec_znx_big_ops.rs | 478 +++++++++++++++-------- 3 files changed, 371 insertions(+), 203 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 395fdf6..ee2bd02 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -13,7 +13,7 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut carry: Vec = 
alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2)); + let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -28,69 +28,95 @@ fn main() { // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); - // ct = (c0, c1) - let mut ct: VecZnx = module.new_vec_znx(2, ct_size); + // Allocates a VecZnx with two columns: ct=(0, 0) + let mut ct: VecZnx = module.new_vec_znx( + 2, // Number of columns + ct_size, // Number of small poly per column + ); - // Fill c1 with random values + // Fill the second column with random values: ct = (0, a) module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, ct.size()); - - // Applies buf_dft <- s * c1 - module.svp_apply_dft( - &mut buf_dft, // DFT(c1 * s) - &s_ppol, - &ct, - 1, // c1 + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft( + 1, // Number of columns + ct.size(), // Number of polynomials per column ); - // Alias scratch space (VecZnxDftis always at least as big as VecZnxBig) + // Applies DFT(ct[1]) * DFT(s) + module.svp_apply_dft( + &mut buf_dft, // DFT(ct[1] * s) + &s_ppol, // DFT(s) + &ct, + 1, // Selects the second column of ct + ); + + // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); - // BIG(c1 * s) <- IDFT(DFT(c1 * s)) (not normalized) + // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - // m <- (0) - let mut m: VecZnx = module.new_vec_znx(1, msg_size); + // Creates a plaintext: VecZnx with 1 column + let mut m: VecZnx = module.new_vec_znx( + 1, // Number of columns + msg_size, // Number of small polynomials + ); let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); 
m.normalize(log_base2k, &mut carry); - // m - BIG(c1 * s) - module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m); + // m - BIG(ct[1] * s) + module.vec_znx_big_sub_small_a_inplace( + &mut buf_big, + 0, // Selects the first column of the receiver + &m, + 0, // Selects the first column of the message + ); - // c0 <- m - BIG(c1 * s) - module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry); + // Normalizes back to VecZnx + // ct[0] <- m - BIG(c1 * s) + module.vec_znx_big_normalize( + log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) + &buf_big, 0, // Selects the first column of buf_big + &mut carry, + ); - ct.print(ct.sl()); - - // (c0 + e, c1) + // Add noise to ct[0] + // ct[0] <- ct[0] + e module.add_normal( log_base2k, &mut ct, - 0, // c0 - log_base2k * ct_size, + 0, // Selects the first column of ct (ct[0]) + log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * limbs} &mut source, - 3.2, - 19.0, + 3.2, // Standard deviation + 19.0, // Truncatation bound ); - // Decrypt + // Final ciphertext: ct = (-a * s + m + e, a) + + // Decryption + + // DFT(ct[1] * s) + module.svp_apply_dft( + &mut buf_dft, + &s_ppol, + &ct, + 1, // Selects the second column of ct (ct[1]) + ); - // DFT(c1 * s) - module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1); // BIG(c1 * s) = IDFT(DFT(c1 * s)) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - // BIG(c1 * s) + c0 - module.vec_znx_big_add_small_inplace(&mut buf_big, &ct); + // BIG(c1 * s) + ct[0] + module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); - // m + e <- BIG(c1 * s + c0) + // m + e <- BIG(ct[1] * s + ct[0]) let mut res: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; diff --git a/base2k/examples/vector_matrix_product.rs 
b/base2k/examples/vector_matrix_product.rs index 96a0df7..2f4b1fb 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -46,7 +46,7 @@ fn main() { module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); - module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; res.decode_vec_i64(0, log_base2k, log_k, &mut values_res); diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index c87c95d..e59fda1 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,8 +1,5 @@ -use std::cmp::min; - use crate::ffi::vec_znx; -use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_1, ffi_ternary_op_factory}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -41,40 +38,80 @@ pub trait VecZnxBigOps { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig); + fn vec_znx_big_add( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); /// Adds `a` to `b` and stores the result on `c`. 
- fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_add_small( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnx, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); /// Subtracts `a` to `b` and stores the result on `c`. - fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig); + fn vec_znx_big_sub( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); /// Subtracts `b` to `a` and stores the result on `c`. - fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); + fn vec_znx_big_sub_small_a( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnx, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); /// Subtracts `b` to `a` and stores the result on `c`. 
- fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx); + fn vec_znx_big_sub_small_b( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnx, + col_b: usize, + ); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize; + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; /// Normalizes `a` and stores the result on `b`. /// @@ -82,13 +119,21 @@ pub trait VecZnxBigOps { /// /// * `log_base2k`: normalization basis. /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); + fn vec_znx_big_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + tmp_bytes: &mut [u8], + ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
- fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize); } impl VecZnxBigOps for Module { @@ -108,170 +153,267 @@ impl VecZnxBigOps for Module { VecZnxBig::bytes_of(self, cols, size) } - fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_add, - ); - apply_binary_op::, VecZnxBig, VecZnxBig, false>(self, c, a, b, op); - } - - fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr); - } - } - - fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_sub, - ); - apply_binary_op::, VecZnxBig, VecZnxBig, true>(self, c, a, b, op); - } - - fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr); - } - } - - fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a); - } - } - - fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_sub, - ); - apply_binary_op::, VecZnxBig, VecZnx, true>(self, c, a, b, op); - } - - fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a); - } - } 
- - fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_sub, - ); - apply_binary_op::, VecZnx, VecZnxBig, true>(self, c, a, b, op); - } - - fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr); - } - } - - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_add, - ); - apply_binary_op::, VecZnx, VecZnxBig, false>(self, c, a, b, op); - } - - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig; - Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr); - } - } - - fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize { - Self::vec_znx_normalize_tmp_bytes(self, cols) - } - - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { + fn vec_znx_big_add( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ) { #[cfg(debug_assertions)] { - assert!(tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(&self, a.cols())); + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + 
unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + } + } + + fn vec_znx_big_sub( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + } + } + + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + } + } + + fn vec_znx_big_sub_small_b( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnxBig, + col_a: usize, + b: &VecZnx, + col_b: usize, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + 
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + } + } + + fn vec_znx_big_sub_small_a( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnx, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + } + } + + fn vec_znx_big_add_small( + &self, + res: &mut VecZnxBig, + col_res: usize, + a: &VecZnx, + col_a: usize, + b: &VecZnxBig, + col_b: usize, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, a_col: usize) { + unsafe { + let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; + Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res); + } + } + + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { + Self::vec_znx_normalize_tmp_bytes(self) + } + + fn vec_znx_big_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnxBig, + 
col_a: usize, + tmp_bytes: &mut [u8], + ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); assert_alignement(tmp_bytes.as_ptr()); } - - let a_size: usize = a.size(); - let b_size: usize = b.size(); - let a_sl: usize = a.sl(); - let b_sl: usize = b.sl(); - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - (0..min_cols).for_each(|i| unsafe { + unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - b.at_mut_ptr(i, 0), - b_size as u64, - b_sl as u64, - a.at_ptr(i, 0), - a_size as u64, - a_sl as u64, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, tmp_bytes.as_mut_ptr(), ); - }); - - (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); + } } - fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig) { - let op = ffi_binary_op_factory_type_1( - self.ptr, - k, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_automorphism, - ); - apply_unary_op::>(self, b, a, op); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize) { unsafe { let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; - Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, &*a_ptr); + Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); } } } From 
6f7b93c7ca7d1234ed3cc04007c608e28c6dd5d7 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 30 Apr 2025 13:43:18 +0200 Subject: [PATCH 19/87] wip major refactoring (compiles & all test + example passing) --- base2k/examples/rlwe_encrypt.rs | 18 +- base2k/examples/vector_matrix_product.rs | 4 +- base2k/src/encoding.rs | 12 +- base2k/src/lib.rs | 18 +- base2k/src/mat_znx_dft.rs | 257 +++++++--------- base2k/src/sampling.rs | 8 +- base2k/src/scalar_znx_dft.rs | 56 ++-- base2k/src/stats.rs | 3 +- base2k/src/vec_znx.rs | 156 +++------- base2k/src/vec_znx_big.rs | 132 +++----- base2k/src/vec_znx_big_ops.rs | 206 ++++++------- base2k/src/vec_znx_dft.rs | 373 +++-------------------- base2k/src/vec_znx_dft_ops.rs | 140 +++++++++ base2k/src/vec_znx_ops.rs | 15 +- base2k/src/{commons.rs => znx_base.rs} | 122 +++++++- rlwe/Cargo.toml | 2 - rlwe/src/automorphism.rs | 8 +- rlwe/src/keys.rs | 2 +- 18 files changed, 662 insertions(+), 870 deletions(-) create mode 100644 base2k/src/vec_znx_dft_ops.rs rename base2k/src/{commons.rs => znx_base.rs} (67%) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index ee2bd02..0f75ef3 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -23,10 +23,10 @@ fn main() { s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: ScalarZnxDft = module.new_svp_ppol(); + let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(); - // s_ppol <- DFT(s) - module.svp_prepare(&mut s_ppol, &s); + // s_dft <- DFT(s) + module.svp_prepare(&mut s_dft, &s); // Allocates a VecZnx with two columns: ct=(0, 0) let mut ct: VecZnx = module.new_vec_znx( @@ -46,16 +46,17 @@ fn main() { // Applies DFT(ct[1]) * DFT(s) module.svp_apply_dft( &mut buf_dft, // DFT(ct[1] * s) - &s_ppol, // DFT(s) + 0, // Selects the first column of res + &s_dft, // DFT(s) &ct, 1, // Selects the second column of ct ); // Alias scratch space (VecZnxDft is always at least as big as 
VecZnxBig) - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); + let mut buf_big: VecZnxBig = buf_dft.alias_as_vec_znx_big(); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column let mut m: VecZnx = module.new_vec_znx( @@ -103,13 +104,14 @@ fn main() { // DFT(ct[1] * s) module.svp_apply_dft( &mut buf_dft, - &s_ppol, + 0, // Selects the first column of res. + &s_dft, &ct, 1, // Selects the second column of ct (ct[1]) ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 2f4b1fb..e565be1 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -42,8 +42,8 @@ fn main() { let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); - let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); - module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); + let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); + module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0); let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf); diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 7f8a0cc..b7d014d 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,6 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{VecZnx, ZnxInfos, ZnxLayout}; +use crate::znx_base::ZnxLayout; +use crate::{VecZnx, znx_base::ZnxInfos}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,7 +263,10 @@ fn 
decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout}; + use crate::{ + Encoding, FFT64, Module, VecZnx, VecZnxOps, + znx_base::{ZnxInfos, ZnxLayout}, + }; use itertools::izip; use sampling::source::Source; @@ -273,7 +277,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, size); + let mut a: VecZnx = module.new_vec_znx(2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -295,7 +299,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, size); + let mut a: VecZnx = module.new_vec_znx(2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7a8a3f8..f57e482 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -1,4 +1,3 @@ -pub mod commons; pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports @@ -12,9 +11,10 @@ pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_big_ops; pub mod vec_znx_dft; +pub mod vec_znx_dft_ops; pub mod vec_znx_ops; +pub mod znx_base; -pub use commons::*; pub use encoding::*; pub use mat_znx_dft::*; pub use module::*; @@ -26,7 +26,9 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_big_ops::*; pub use vec_znx_dft::*; +pub use vec_znx_dft_ops::*; pub use vec_znx_ops::*; +pub use znx_base::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; @@ -110,14 +112,8 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { unsafe { 
Vec::from_raw_parts(ptr, len, cap) } } -// Allocates an aligned of size equal to the smallest power of two equal or greater to `size` that is -// at least as bit as DEFAULTALIGN / std::mem::size_of::(). +/// Allocates an aligned of size equal to the smallest multiple +/// of [DEFAULTALIGN] that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::( - std::cmp::max( - size.next_power_of_two(), - DEFAULTALIGN / std::mem::size_of::(), - ), - DEFAULTALIGN, - ) + alloc_aligned_custom::(size + (size % DEFAULTALIGN), DEFAULTALIGN) } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index b40ed71..9b5e2ca 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,103 +1,75 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft]. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. /// -/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. -/// See the trait [VmpPMatOps] for additional information. +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. pub struct MatZnxDft { - /// Raw data, is empty if borrowing scratch space. - data: Vec, - /// Pointer to data. 
Can point to scratch space. - ptr: *mut u8, - /// The ring degree of each polynomial. - n: usize, - /// Number of rows - rows: usize, - /// Number of cols - cols: usize, - /// The number of small polynomials - size: usize, + pub inner: ZnxBase, _marker: PhantomData, } -impl ZnxInfos for MatZnxDft { - fn n(&self) -> usize { - self.n +impl GetZnxBase for MatZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn rows(&self) -> usize { - self.rows - } - - fn cols(&self) -> usize { - self.cols - } - - fn size(&self) -> usize { - self.size + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } } -impl MatZnxDft { - fn new(module: &Module, rows: usize, cols: usize, size: usize) -> MatZnxDft { - let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, size)); - let ptr: *mut u8 = data.as_mut_ptr(); - MatZnxDft:: { - data: data, - ptr: ptr, - n: module.n(), - rows: rows, - cols: cols, - size: size, +impl ZnxInfos for MatZnxDft {} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() + } +} + +impl ZnxLayout for MatZnxDft { + type Scalar = f64; +} + +impl ZnxAlloc for MatZnxDft { + type Scalar = u8; + + fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + Self { + inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes), _marker: PhantomData, } } - pub fn as_ptr(&self) -> *const u8 { - self.ptr + fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } } +} - pub fn as_mut_ptr(&self) -> *mut u8 { - self.ptr - } - - pub fn borrowed(&self) -> bool { - self.data.len() == 0 - } - - /// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat]. 
- pub fn raw(&self) -> &[f64] { - let ptr: *const f64 = self.ptr as *const f64; - let size: usize = self.n() * self.poly_count(); - unsafe { &std::slice::from_raw_parts(ptr, size) } - } - - /// Returns a mutable reference of to the entire contiguous array of the [VmpPMat]. - pub fn raw_mut(&self) -> &mut [f64] { - let ptr: *mut f64 = self.ptr as *mut f64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. +impl MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. /// /// # Arguments /// /// * `row`: row index (i). /// * `col`: col index (j). - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = alloc_aligned(self.n); + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); - if self.n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]); + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); } else { - (0..self.n >> 3).for_each(|blk| { + (0..n >> 3).for_each(|blk| { res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); }); } @@ -105,6 +77,7 @@ impl MatZnxDft { res } + #[allow(dead_code)] fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); let nsize: usize = self.size(); @@ -117,11 +90,11 @@ impl MatZnxDft { } /// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [VmpPMat]. +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. pub trait MatZnxDftOps { fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; - /// Allocates a new [VmpPMat] with the given number of rows and columns. 
+ /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments /// @@ -129,83 +102,83 @@ pub trait MatZnxDftOps { /// * `size`: number of size (number of size of each [VecZnxDft]). fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; - /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. + /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. /// /// # Arguments /// - /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - /// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. + /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. + /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; - /// Prepares a [VmpPMat] from a contiguous array of [i64]. + /// Prepares a [MatZnxDft] from a contiguous array of [i64]. /// The helper struct [Matrix3D] can be used to contruct and populate /// the appropriate contiguous array. /// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft]. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); - /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. + /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. 
/// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. /// * `row_i`: the index of the row to prepare. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); - /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. /// /// # Arguments /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. + /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); - /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. + /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the [VecZnxDft] to encode on the [VmpPMat]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft]. /// * `row_i`: the index of the row to prepare. /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. 
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); - /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// /// # Arguments /// - /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. + /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// /// # Arguments /// /// * `c_size`: number of size of the output [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `size`: number of size of the input [VmpPMat]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. 
/// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -221,17 +194,17 @@ pub trait MatZnxDftOps { /// /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -247,28 +220,28 @@ pub trait MatZnxDftOps { /// /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. 
+ /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// /// # Arguments /// /// * `c_size`: number of size of the output [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `size`: number of size of the input [VmpPMat]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -284,18 +257,18 @@ pub trait MatZnxDftOps { /// /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the left operand [VecZnxDft] of the vector matrix product. 
- /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -311,18 +284,18 @@ pub trait MatZnxDftOps { /// /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. 
+ /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -337,8 +310,8 @@ pub trait MatZnxDftOps { /// # Arguments /// /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `a`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); } @@ -404,7 +377,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_extract_row( self.ptr, - b.ptr as *mut vec_znx_big_t, + b.as_mut_ptr() as *mut vec_znx_big_t, a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, @@ -423,7 +396,7 @@ impl MatZnxDftOps for Module { vmp::vmp_prepare_row_dft( self.ptr, b.as_mut_ptr() as *mut vmp_pmat_t, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, b.size() as u64, @@ -440,7 +413,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_extract_row_dft( self.ptr, - b.ptr as *mut vec_znx_dft_t, + b.as_mut_ptr() as *mut vec_znx_dft_t, a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, @@ -470,7 +443,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, a.as_ptr(), a.size() as u64, @@ -492,7 +465,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_add( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, a.as_ptr(), a.size() as u64, @@ -526,9 +499,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, @@ -553,9 +526,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft_add( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, @@ -574,9 +547,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft( self.ptr, - b.ptr as *mut vec_znx_dft_t, + 
b.as_mut_ptr() as *mut vec_znx_dft_t, b.size() as u64, - b.ptr as *mut vec_znx_dft_t, + b.as_ptr() as *mut vec_znx_dft_t, b.size() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, @@ -591,7 +564,7 @@ impl MatZnxDftOps for Module { mod tests { use crate::{ FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - ZnxLayout, alloc_aligned, + alloc_aligned, znx_base::ZnxLayout, }; use sampling::source::Source; @@ -614,7 +587,7 @@ mod tests { for row_i in 0..vpmat_rows { let mut source: Source = Source::new([0u8; 32]); module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); - module.vec_znx_dft(&mut a_dft, &a); + module.vec_znx_dft(&mut a_dft, 0, &a, 0); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) @@ -627,7 +600,7 @@ mod tests { // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); - module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); + module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes); assert_eq!(a_big.raw(), b_big.raw()); } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 5261207..b52c4db 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Module, VecZnx, ZnxLayout}; +use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; @@ -106,7 +106,7 @@ impl Sampling for Module { #[cfg(test)] mod tests { use super::Sampling; - use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout}; + use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout}; use sampling::source::Source; #[test] @@ -120,7 +120,7 @@ mod tests { let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = 
VecZnx::new(&module, cols, size); + let mut a: VecZnx = module.new_vec_znx(cols, size); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -154,7 +154,7 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, size); + let mut a: VecZnx = module.new_vec_znx(cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 474135b..07e156d 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,9 +2,8 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement}; - -use crate::{ZnxInfos, alloc_aligned, cast_mut}; +use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -118,11 +117,14 @@ impl Scalar { pub fn as_vec_znx(&self) -> VecZnx { VecZnx { - n: self.n, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr, + inner: ZnxBase { + n: self.n, + rows: 1, + cols: 1, + size: 1, + data: Vec::new(), + ptr: self.ptr as *mut u8, + }, } } } @@ -159,7 +161,7 @@ pub struct ScalarZnxDft { /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. impl ScalarZnxDft { pub fn new(module: &Module) -> Self { - module.new_svp_ppol() + module.new_scalar_znx_dft() } /// Returns the ring degree of the [SvpPPol]. 
@@ -168,14 +170,14 @@ impl ScalarZnxDft { } pub fn bytes_of(module: &Module) -> usize { - module.bytes_of_svp_ppol() + module.bytes_of_scalar_znx_dft() } pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(bytes.as_ptr()); - assert_eq!(bytes.len(), module.bytes_of_svp_ppol()); + assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft()); } unsafe { Self { @@ -191,7 +193,7 @@ impl ScalarZnxDft { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol()); + assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft()); } Self { n: module.n(), @@ -209,33 +211,33 @@ impl ScalarZnxDft { pub trait ScalarZnxDftOps { /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> ScalarZnxDft; + fn new_scalar_znx_dft(&self) -> ScalarZnxDft; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. - fn bytes_of_svp_ppol(&self) -> usize; + fn bytes_of_scalar_znx_dft(&self) -> usize; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. 
- fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); } impl ScalarZnxDftOps for Module { - fn new_svp_ppol(&self) -> ScalarZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); + fn new_scalar_znx_dft(&self) -> ScalarZnxDft { + let mut data: Vec = alloc_aligned::(self.bytes_of_scalar_znx_dft()); let ptr: *mut u8 = data.as_mut_ptr(); ScalarZnxDft:: { data: data, @@ -245,28 +247,28 @@ impl ScalarZnxDftOps for Module { } } - fn bytes_of_svp_ppol(&self) -> usize { + fn bytes_of_scalar_znx_dft(&self) -> usize { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { + fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { ScalarZnxDft::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { + fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar) { - unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } + fn svp_prepare(&self, res: &mut ScalarZnxDft, a: &Scalar) { + unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { unsafe { svp::svp_apply_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.size() as u64, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, a.ptr as *const svp_ppol_t, b.at_ptr(b_col, 0), b.size() as u64, diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 4e2a512..a1946ab 100644 --- a/base2k/src/stats.rs +++ 
b/base2k/src/stats.rs @@ -1,4 +1,5 @@ -use crate::{Encoding, VecZnx, ZnxInfos}; +use crate::znx_base::ZnxInfos; +use crate::{Encoding, VecZnx}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 53aeb39..125f32e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,12 +1,13 @@ use crate::Backend; -use crate::ZnxBase; +use crate::Module; +use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::switch_degree; -use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout}; -use crate::{alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; use std::cmp::min; +pub const VEC_ZNX_ROWS: usize = 1; + /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array @@ -17,56 +18,54 @@ use std::cmp::min; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. -#[derive(Clone)] pub struct VecZnx { - /// Polynomial degree. - pub n: usize, - - /// The number of polynomials - pub cols: usize, - - /// The number of size per polynomial (a.k.a small polynomials). - pub size: usize, - - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. - pub data: Vec, - - /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). 
- pub ptr: *mut i64, + pub inner: ZnxBase, } -impl ZnxInfos for VecZnx { - fn n(&self) -> usize { - self.n +impl GetZnxBase for VecZnx { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn rows(&self) -> usize { - 1 + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } +} - fn cols(&self) -> usize { - self.cols - } +impl ZnxInfos for VecZnx {} - fn size(&self) -> usize { - self.size +impl ZnxSliceSize for VecZnx { + fn sl(&self) -> usize { + self.cols() * self.n() } } impl ZnxLayout for VecZnx { type Scalar = i64; - - fn as_ptr(&self) -> *const Self::Scalar { - self.ptr - } - - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr - } } impl ZnxBasics for VecZnx {} +impl ZnxAlloc for VecZnx { + type Scalar = i64; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { + debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); + VecZnx { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), + } + } + + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_ROWS, + "rows != {} not supported for VecZnx", + VEC_ZNX_ROWS + ); + module.n() * cols * size * size_of::() + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { data_b[..size].copy_from_slice(&data_a[..size]) } -impl ZnxBase for VecZnx { - type Scalar = i64; - - /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. 
- fn new(module: &Module, cols: usize, size: usize) -> Self { - let n: usize = module.n(); - #[cfg(debug_assertions)] - { - assert!(n > 0); - assert!(n & (n - 1) == 0); - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - cols: cols, - size: size, - data: data, - ptr: ptr, - } - } - - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - module.n() * cols * size * size_of::() - } - - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the size of data is equal to [Self::bytes_of]. - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - let n: usize = module.n(); - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - size: size, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: module.n(), - cols: cols, - size: size, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } -} - impl VecZnx { /// Truncates the precision of the [VecZnx] by k bits. 
/// @@ -165,11 +90,12 @@ impl VecZnx { } if !self.borrowing() { - self.data + self.inner + .data .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); } - self.size -= k / log_base2k; + self.inner.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -185,10 +111,6 @@ impl VecZnx { copy_vec_znx_from(self, a); } - pub fn borrowing(&self) -> bool { - self.data.len() == 0 - } - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { normalize(log_base2k, self, carry) } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 67b75a2..cbcd4b9 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,115 +1,71 @@ use crate::ffi::vec_znx_big; -use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, NTT120}; use std::marker::PhantomData; +const VEC_ZNX_BIG_ROWS: usize = 1; + pub struct VecZnxBig { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub cols: usize, - pub size: usize, + pub inner: ZnxBase, pub _marker: PhantomData, } -impl ZnxBasics for VecZnxBig {} - -impl ZnxBase for VecZnxBig { - type Scalar = u8; - - fn new(module: &Module, cols: usize, size: usize) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); - let ptr: *mut Self::Scalar = data.as_mut_ptr(); - Self { - data: data, - ptr: ptr, - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } +impl GetZnxBase for VecZnxBig { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } - } - - /// Returns a new [VecZnxBig] with the provided data as backing array. 
- /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()) - }; - unsafe { - Self { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } } -impl ZnxInfos for VecZnxBig { - fn n(&self) -> usize { - self.n +impl ZnxInfos for VecZnxBig {} + +impl ZnxAlloc for VecZnxBig { + type Scalar = u8; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + VecZnxBig { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), + _marker: PhantomData, + } } - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn size(&self) -> usize { - self.size + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_BIG_ROWS, + "rows != {} not supported for VecZnxBig", + VEC_ZNX_BIG_ROWS + ); + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } } } impl ZnxLayout for VecZnxBig { type Scalar = i64; +} - fn as_ptr(&self) -> *const Self::Scalar { - 
self.ptr as *const Self::Scalar - } +impl ZnxLayout for VecZnxBig { + type Scalar = i128; +} - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr as *mut Self::Scalar +impl ZnxBasics for VecZnxBig {} + +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + self.n() } } +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + self.n() * 4 + } +} + +impl ZnxBasics for VecZnxBig {} + impl VecZnxBig { pub fn print(&self, n: usize) { (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index e59fda1..9c6feee 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,5 +1,6 @@ -use crate::ffi::vec_znx; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; +use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -17,7 +18,7 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -41,74 +42,74 @@ pub trait VecZnxBigOps { fn vec_znx_big_add( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. 
- fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small( &self, res: &mut VecZnxBig, - col_res: usize, - a: &VecZnx, - col_a: usize, - b: &VecZnxBig, - col_b: usize, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + b: &VecZnx, + b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `a` to `b` and stores the result on `c`. fn vec_znx_big_sub( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. 
- fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnx, - col_b: usize, + b_col: usize, ); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; @@ -123,44 +124,44 @@ pub trait VecZnxBigOps { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
- fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); } impl VecZnxBigOps for Module { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { - VecZnxBig::new(self, cols, size) + VecZnxBig::new(self, 1, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, size, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig { + VecZnxBig::from_bytes(self, 1, cols, size, bytes) } fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) + VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, cols, size) + VecZnxBig::bytes_of(self, 1, cols, size) } fn vec_znx_big_add( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -170,36 +171,33 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_add( + vec_znx_big::vec_znx_big_add( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let 
res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } fn vec_znx_big_sub( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -209,43 +207,40 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: 
usize, b: &VecZnx, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -255,36 +250,34 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub_small_b( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -294,36 +287,34 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub_small_a( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { 
let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } fn vec_znx_big_add_small( &self, res: &mut VecZnxBig, - col_res: usize, - a: &VecZnx, - col_a: usize, - b: &VecZnxBig, - col_b: usize, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + b: &VecZnx, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -333,25 +324,23 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_add( + vec_znx_big::vec_znx_big_add_small( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res); + Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } @@ -363,9 +352,9 @@ impl VecZnxBigOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ) { #[cfg(debug_assertions)] @@ -376,44 +365,41 @@ impl VecZnxBigOps for Module { assert_alignement(tmp_bytes.as_ptr()); } unsafe { - vec_znx::vec_znx_normalize_base2k( + vec_znx_big::vec_znx_big_normalize_base2k( self.ptr, log_base2k as u64, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - 
a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); } unsafe { - vec_znx::vec_znx_automorphism( + vec_znx_big::vec_znx_big_automorphism( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, ) } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { unsafe { let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; - Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 1b88af5..09ee971 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,129 +1,54 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; -use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; -use crate::{VecZnx, alloc_aligned}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnxBig}; use std::marker::PhantomData; +const VEC_ZNX_DFT_ROWS: usize = 1; + pub struct VecZnxDft { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub cols: usize, - pub size: usize, + inner: ZnxBase, 
pub _marker: PhantomData, } -impl ZnxBase for VecZnxDft { +impl GetZnxBase for VecZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner + } + + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner + } +} + +impl ZnxInfos for VecZnxDft {} + +impl ZnxAlloc for VecZnxDft { type Scalar = u8; - fn new(module: &Module, cols: usize, size: usize) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); - let ptr: *mut Self::Scalar = data.as_mut_ptr(); - Self { - data: data, - ptr: ptr, - n: module.n(), - size: size, - cols: cols, + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + VecZnxDft { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } - } - - /// Returns a new [VecZnxDft] with the provided data as backing array. - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. 
- fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()) - } - unsafe { - Self { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } -} - -impl VecZnxDft { - /// Cast a [VecZnxDft] into a [VecZnxBig]. - /// The returned [VecZnxBig] shares the backing array - /// with the original [VecZnxDft]. 
- pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig:: { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - cols: self.cols, - size: self.size, - _marker: PhantomData, - } - } -} - -impl ZnxInfos for VecZnxDft { - fn n(&self) -> usize { - self.n - } - - fn rows(&self) -> usize { - 1 - } - - fn cols(&self) -> usize { - self.cols - } - - fn size(&self) -> usize { - self.size + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_DFT_ROWS, + "rows != {} not supported for VecZnxDft", + VEC_ZNX_DFT_ROWS + ); + unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } } impl ZnxLayout for VecZnxDft { type Scalar = f64; +} - fn as_ptr(&self) -> *const Self::Scalar { - self.ptr as *const Self::Scalar - } - - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr as *mut Self::Scalar +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + self.n() } } @@ -133,225 +58,21 @@ impl VecZnxDft { } } -pub trait VecZnxDftOps { - /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. 
- /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. - fn vec_znx_idft_tmp_bytes(&self) -> usize; - - /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); - - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); - - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); - - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; -} - -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { - VecZnxDft::::new(&self, cols, size) - } - - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, cols, size, tmp_bytes) - } - - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes) - } - - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, size) - } - - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - b.ptr as *mut 
vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *mut vec_znx_dft_t, - a.poly_count() as u64, - ) +impl VecZnxDft { + /// Cast a [VecZnxDft] into a [VecZnxBig]. + /// The returned [VecZnxBig] shares the backing array + /// with the original [VecZnxDft]. + pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { + inner: ZnxBase { + data: Vec::new(), + ptr: self.ptr(), + n: self.n(), + rows: self.rows(), + cols: self.cols(), + size: self.size(), + }, + _marker: PhantomData, } } - - fn vec_znx_idft_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } - } - - /// b <- DFT(a) - /// - /// # Panics - /// If b.cols < a_cols - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { - unsafe { - vec_znx_dft::vec_znx_dft( - self.ptr, - b.ptr as *mut vec_znx_dft_t, - b.size() as u64, - a.as_ptr(), - a.size() as u64, - (a.n() * a.cols()) as u64, - ) - } - } - - // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_idft_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_dft::vec_znx_idft( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - b.ptr as *mut vec_znx_dft_t, - b.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - [0u8; 0].as_mut_ptr(), - ); - } - } - - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - 
assert!( - tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_dft_automorphism_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } - println!("{}", a.poly_count()); - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - a.ptr as *mut vec_znx_dft_t, - a.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize } - } -} - -#[cfg(test)] -mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned}; - use itertools::izip; - use sampling::source::Source; - - #[test] - fn test_automorphism_dft() { - let n: usize = 8; - let module: Module = Module::::new(n); - - let size: usize = 2; - let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, size); - - let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, size, &mut source); - - let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); - - let p: i64 = -5; - - // a_dft <- DFT(a) - module.vec_znx_dft(&mut a_dft, &a); - - // a_dft <- AUTO(a_dft) - module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); - - // a <- AUTO(a) - module.vec_znx_automorphism_inplace(p, &mut a, 0); - - // b_dft <- DFT(AUTO(a)) - module.vec_znx_dft(&mut b_dft, &a); - - let a_f64: &[f64] = a_dft.raw(); - let b_f64: &[f64] = b_dft.raw(); - izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { - assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); - }); - - module.free() - } } diff --git 
a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs new file mode 100644 index 0000000..57b3777 --- /dev/null +++ b/base2k/src/vec_znx_dft_ops.rs @@ -0,0 +1,140 @@ +use crate::ffi::vec_znx_big; +use crate::ffi::vec_znx_dft; +use crate::znx_base::ZnxAlloc; +use crate::znx_base::ZnxInfos; +use crate::znx_base::ZnxLayout; +use crate::znx_base::ZnxSliceSize; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement}; + +pub trait VecZnxDftOps { + /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// Behavior: takes ownership of the backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// Behavior: the backing array is only borrowed. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. 
+ fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; + + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. + fn vec_znx_idft_tmp_bytes(&self) -> usize; + + /// b <- IDFT(a), uses a as scratch space. + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + + fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]); + + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); +} + +impl VecZnxDftOps for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { + VecZnxDft::::new(&self, 1, cols, size) + } + + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft { + VecZnxDft::from_bytes(self, 1, cols, size, bytes) + } + + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) + } + + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + VecZnxDft::bytes_of(&self, 1, cols, size) + } + + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(res.poly_count(), a.poly_count()); + } + + unsafe { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, + res.size() as u64, + a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, + a.size() as u64, + ) + } + } + + fn vec_znx_idft_tmp_bytes(&self) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } + } + + /// b <- DFT(a) + /// + /// # Panics + /// If b.cols < a_cols + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { + unsafe { + vec_znx_dft::vec_znx_dft( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) 
as *mut vec_znx_dft::vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. + fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert!( + tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_idft_tmp_bytes(self) + ); + assert_alignement(tmp_bytes.as_ptr()) + } + unsafe { + vec_znx_dft::vec_znx_idft( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, + res.size() as u64, + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t, + a.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 9f2d43a..7ee1529 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,5 +1,6 @@ use crate::ffi::vec_znx; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; +use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -19,7 +20,7 @@ pub trait VecZnxOps { /// /// # Panic /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx; /// Instantiates a new [VecZnx] from a slice of bytes. /// The returned [VecZnx] does take ownership of the slice of bytes. 
@@ -107,19 +108,19 @@ impl VecZnxOps for Module { fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { - VecZnx::new(self, cols, size) + VecZnx::new(self, VEC_ZNX_ROWS, cols, size) } fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnx::bytes_of(self, cols, size) + VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size) } - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self, cols, size, bytes) + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx { + VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes) } fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) + VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self) -> usize { diff --git a/base2k/src/commons.rs b/base2k/src/znx_base.rs similarity index 67% rename from base2k/src/commons.rs rename to base2k/src/znx_base.rs index d5f60ee..64ad85f 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/znx_base.rs @@ -1,10 +1,37 @@ -use crate::{Backend, Module, assert_alignement, cast_mut}; +use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut}; use itertools::izip; use std::cmp::min; -pub trait ZnxInfos { +pub struct ZnxBase { + /// The ring degree + pub n: usize, + + /// The number of rows (in the third dimension) + pub rows: usize, + + /// The number of polynomials + pub cols: usize, + + /// The number of size per polynomial (a.k.a small polynomials). + pub size: usize, + + /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. + pub data: Vec, + + /// Pointer to data (data can be empty if [VecZnx] borrows space instead of owning it). 
+ pub ptr: *mut u8, +} + +pub trait GetZnxBase { + fn znx(&self) -> &ZnxBase; + fn znx_mut(&mut self) -> &mut ZnxBase; +} + +pub trait ZnxInfos: GetZnxBase { /// Returns the ring degree of the polynomials. - fn n(&self) -> usize; + fn n(&self) -> usize { + self.znx().n + } /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { @@ -12,41 +39,104 @@ pub trait ZnxInfos { } /// Returns the number of rows. - fn rows(&self) -> usize; - + fn rows(&self) -> usize { + self.znx().rows + } /// Returns the number of polynomials in each row. - fn cols(&self) -> usize; + fn cols(&self) -> usize { + self.znx().cols + } /// Returns the number of size per polynomial. - fn size(&self) -> usize; + fn size(&self) -> usize { + self.znx().size + } + + fn data(&self) -> &[u8] { + &self.znx().data + } + + fn ptr(&self) -> *mut u8 { + self.znx().ptr + } /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } +} +pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. 
- fn sl(&self) -> usize { - self.n() * self.cols() + fn sl(&self) -> usize; +} + +impl ZnxBase { + pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { + let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); + res.data = bytes; + res + } + + pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert_eq!(n & (n - 1), 0, "n must be a power of two"); + assert!(n > 0, "n must be greater than 0"); + assert!(rows > 0, "rows must be greater than 0"); + assert!(cols > 0, "cols must be greater than 0"); + assert!(size > 0, "size must be greater than 0"); + } + Self { + n: n, + rows: rows, + cols: cols, + size: size, + data: Vec::new(), + ptr: bytes.as_mut_ptr(), + } } } -pub trait ZnxBase { +pub trait ZnxAlloc +where + Self: Sized + ZnxInfos, +{ type Scalar; - fn new(module: &Module, cols: usize, size: usize) -> Self; - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize; + fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { + let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); + Self::from_bytes(module, rows, cols, size, bytes) + } + + fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { + let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); + res.znx_mut().data = bytes; + res + } + + fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + + fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; } + pub trait ZnxLayout: ZnxInfos { type Scalar; + /// Returns true if the receiver is only borrowing the data. 
+ fn borrowing(&self) -> bool { + self.znx().data.len() == 0 + } + /// Returns a non-mutable pointer to the underlying coefficients array. - fn as_ptr(&self) -> *const Self::Scalar; + fn as_ptr(&self) -> *const Self::Scalar { + self.znx().ptr as *const Self::Scalar + } /// Returns a mutable pointer to the underlying coefficients array. - fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.znx_mut().ptr as *mut Self::Scalar + } /// Returns a non-mutable reference to the entire underlying coefficient array. fn raw(&self) -> &[Self::Scalar] { diff --git a/rlwe/Cargo.toml b/rlwe/Cargo.toml index a8b8207..0822281 100644 --- a/rlwe/Cargo.toml +++ b/rlwe/Cargo.toml @@ -1,5 +1,3 @@ -cargo-features = ["edition2024"] - [package] name = "rlwe" version = "0.1.0" diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index d76e356..95a935f 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -20,7 +20,7 @@ pub struct AutomorphismKey { } pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) + module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) } impl Parameters { @@ -103,10 +103,10 @@ impl AutomorphismKey { tmp_bytes: &mut [u8], ) -> Vec { let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); - let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); + let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft()); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); - let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); + let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes); let mut keys: 
Vec = Vec::new(); @@ -116,7 +116,7 @@ impl AutomorphismKey { let p_inv: i64 = module.galois_element_inv(*pi); module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); - module.svp_prepare(&mut sk_out, &sk_auto); + module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto); encrypt_grlwe_sk( module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, ); diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 6017159..511f755 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -20,7 +20,7 @@ impl SecretKey { } pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { - module.svp_prepare(sk_ppol, &self.0) + module.scalar_znx_dft_prepare(sk_ppol, &self.0) } } From 9ade995cd70f3409cc257365be7ebe2f20385fb0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 30 Apr 2025 23:11:43 +0200 Subject: [PATCH 20/87] reworked scalar --- base2k/examples/rlwe_encrypt.rs | 14 +- base2k/src/lib.rs | 28 +-- base2k/src/mat_znx_dft.rs | 14 +- base2k/src/scalar_znx.rs | 113 ++++++++++++ base2k/src/scalar_znx_dft.rs | 293 +++++-------------------------- base2k/src/scalar_znx_dft_ops.rs | 63 +++++++ base2k/src/vec_znx_dft.rs | 2 +- base2k/src/vec_znx_ops.rs | 122 ++++++------- 8 files changed, 311 insertions(+), 338 deletions(-) create mode 100644 base2k/src/scalar_znx.rs create mode 100644 base2k/src/scalar_znx_dft_ops.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 0f75ef3..07fe1c6 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, + Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, + VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, }; use itertools::izip; use 
sampling::source::Source; @@ -19,14 +19,14 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar = Scalar::new(n); - s.fill_ternary_prob(0.5, &mut source); + let mut s: Scalar = module.new_scalar(1); + s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(); + let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) - module.svp_prepare(&mut s_dft, &s); + module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) let mut ct: VecZnx = module.new_vec_znx( @@ -48,6 +48,7 @@ fn main() { &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) + 0, // Selects the first column of s_dft &ct, 1, // Selects the second column of ct ); @@ -106,6 +107,7 @@ fn main() { &mut buf_dft, 0, // Selects the first column of res. &s_dft, + 0, &ct, 1, // Selects the second column of ct (ct[1]) ); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f57e482..3fa0bbe 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -5,7 +5,9 @@ pub mod ffi; pub mod mat_znx_dft; pub mod module; pub mod sampling; +pub mod scalar_znx; pub mod scalar_znx_dft; +pub mod scalar_znx_dft_ops; pub mod stats; pub mod vec_znx; pub mod vec_znx_big; @@ -19,8 +21,11 @@ pub use encoding::*; pub use mat_znx_dft::*; pub use module::*; pub use sampling::*; +#[allow(unused_imports)] +pub use scalar_znx::*; pub use scalar_znx_dft::*; #[allow(unused_imports)] +pub use scalar_znx_dft_ops::*; pub use stats::*; pub use vec_znx::*; pub use vec_znx_big::*; @@ -50,13 +55,13 @@ pub fn assert_alignement(ptr: *const T) { pub fn cast(data: &[T]) -> &[V] { let ptr: *const V = data.as_ptr() as *const V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts(ptr, len) } } pub fn cast_mut(data: &[T]) -> &mut [V] { 
let ptr: *mut V = data.as_ptr() as *mut V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts_mut(ptr, len) } } @@ -70,7 +75,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, @@ -98,22 +103,25 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, align ); - let mut vec_u8: Vec = alloc_aligned_custom_u8(std::mem::size_of::() * size, align); + let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; - let len: usize = vec_u8.len() / std::mem::size_of::(); - let cap: usize = vec_u8.capacity() / std::mem::size_of::(); + let len: usize = vec_u8.len() / size_of::(); + let cap: usize = vec_u8.capacity() / size_of::(); std::mem::forget(vec_u8); unsafe { Vec::from_raw_parts(ptr, len, cap) } } -/// Allocates an aligned of size equal to the smallest multiple -/// of [DEFAULTALIGN] that is equal or greater to `size`. +/// Allocates an aligned vector of size equal to the smallest multiple +/// of [DEFAULTALIGN]/size_of::() that is equal or greater to `size`. 
pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::(size + (size % DEFAULTALIGN), DEFAULTALIGN) + alloc_aligned_custom::( + size + (size % (DEFAULTALIGN / size_of::())), + DEFAULTALIGN, + ) } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 9b5e2ca..44d44df 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -160,7 +160,7 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// @@ -170,7 +170,7 @@ pub trait MatZnxDftOps { /// * `a_size`: number of size of the input [VecZnx]. /// * `rows`: number of rows of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; + fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. 
/// @@ -404,7 +404,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -422,14 +422,14 @@ impl MatZnxDftOps for Module { } } - fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { + fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.ptr, res_size as u64, a_size as u64, - gct_rows as u64, - gct_size as u64, + b_rows as u64, + b_size as u64, ) as usize } } @@ -595,7 +595,7 @@ mod tests { assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); + module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0); assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs new file mode 100644 index 0000000..df3e6d1 --- /dev/null +++ b/base2k/src/scalar_znx.rs @@ -0,0 +1,113 @@ +use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, GetZnxBase, Module, VecZnx}; +use rand::seq::SliceRandom; +use rand_core::RngCore; +use rand_distr::{Distribution, weighted::WeightedIndex}; +use sampling::source::Source; + +pub const SCALAR_ZNX_ROWS: usize = 1; +pub const SCALAR_ZNX_SIZE: usize = 1; + +pub struct Scalar { + pub inner: ZnxBase, +} + +impl GetZnxBase for Scalar { + fn znx(&self) -> &ZnxBase { + &self.inner + } + + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner + } +} + +impl ZnxInfos for Scalar {} + +impl ZnxAlloc for Scalar { + type Scalar = i64; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) 
-> Self { + Self { + inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), + } + } + + fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { + debug_assert_eq!( + _rows, SCALAR_ZNX_ROWS, + "rows != {} not supported for Scalar", + SCALAR_ZNX_ROWS + ); + debug_assert_eq!( + _size, SCALAR_ZNX_SIZE, + "rows != {} not supported for Scalar", + SCALAR_ZNX_SIZE + ); + module.n() * cols * std::mem::size_of::() + } +} + +impl ZnxLayout for Scalar { + type Scalar = i64; +} + +impl ZnxSliceSize for Scalar { + fn sl(&self) -> usize { + self.n() + } +} + +impl Scalar { + pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 3] = [-1, 0, 1]; + let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); + self.at_mut(col, 0).shuffle(source); + } + + pub fn alias_as_vec_znx(&self) -> VecZnx { + VecZnx { + inner: ZnxBase { + n: self.n(), + rows: 1, + cols: 1, + size: 1, + data: Vec::new(), + ptr: self.ptr() as *mut u8, + }, + } + } +} + +pub trait ScalarOps { + fn bytes_of_scalar(&self, cols: usize) -> usize; + fn new_scalar(&self, cols: usize) -> Scalar; + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar; + fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; +} + +impl ScalarOps for Module { + fn bytes_of_scalar(&self, cols: usize) -> usize { + Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + } + fn new_scalar(&self, cols: usize) -> Scalar { + Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + } + fn 
new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar { + Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + } + fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { + Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + } +} diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 07e156d..ffb54b5 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -1,279 +1,66 @@ use std::marker::PhantomData; -use crate::ffi::svp::{self, svp_ppol_t}; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut}; -use rand::seq::SliceRandom; -use rand_core::RngCore; -use rand_distr::{Distribution, weighted::WeightedIndex}; -use sampling::source::Source; +use crate::ffi::svp; +use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, GetZnxBase, Module}; -pub struct Scalar { - pub n: usize, - pub data: Vec, - pub ptr: *mut i64, -} - -impl Module { - pub fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } -} - -impl Scalar { - pub fn new(n: usize) -> Self { - let mut data: Vec = alloc_aligned::(n); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - data: data, - ptr: ptr, - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn bytes_of(n: usize) -> usize { - n * std::mem::size_of::() - } - - pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: 
Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: Vec::new(), - ptr: ptr, - } - } - - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n) } - } - - pub fn raw_mut(&self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) } - } - - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - let choices: [i64; 3] = [-1, 0, 1]; - let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; - let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); - self.data - .iter_mut() - .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - assert!(hw <= self.n()); - self.data[..hw] - .iter_mut() - .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); - self.data.shuffle(source); - } - - pub fn as_vec_znx(&self) -> VecZnx { - VecZnx { - inner: ZnxBase { - n: self.n, - rows: 1, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr as *mut u8, - }, - } - } -} - -pub trait ScalarOps { - fn bytes_of_scalar(&self) -> usize; - fn new_scalar(&self) -> Scalar; - fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar; - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar; -} -impl ScalarOps for Module { - fn bytes_of_scalar(&self) -> usize { - Scalar::bytes_of(self.n()) - } - fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } - fn 
new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes(self.n(), bytes) - } - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes_borrow(self.n(), tmp_bytes) - } -} +pub const SCALAR_ZNX_DFT_ROWS: usize = 1; +pub const SCALAR_ZNX_DFT_SIZE: usize = 1; pub struct ScalarZnxDft { - pub n: usize, - pub data: Vec, - pub ptr: *mut u8, + pub inner: ZnxBase, _marker: PhantomData, } -/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. -/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl ScalarZnxDft { - pub fn new(module: &Module) -> Self { - module.new_scalar_znx_dft() +impl GetZnxBase for ScalarZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner } - /// Returns the ring degree of the [SvpPPol]. - pub fn n(&self) -> usize { - self.n + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } +} - pub fn bytes_of(module: &Module) -> usize { - module.bytes_of_scalar_znx_dft() - } +impl ZnxInfos for ScalarZnxDft {} - pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()); - assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft()); - } - unsafe { - Self { - n: module.n(), - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - _marker: PhantomData, - } - } - } +impl ZnxAlloc for ScalarZnxDft { + type Scalar = u8; - pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft()); - } + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { Self { - n: module.n(), - data: Vec::new(), - ptr: tmp_bytes.as_mut_ptr(), + inner: ZnxBase::from_bytes_borrow( + module.n(), + SCALAR_ZNX_DFT_ROWS, + cols, + SCALAR_ZNX_DFT_SIZE, + bytes, + ), _marker: PhantomData, } } - /// Returns 
the number of cols of the [SvpPPol], which is always 1. - pub fn cols(&self) -> usize { - 1 + fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { + debug_assert_eq!( + _rows, SCALAR_ZNX_DFT_ROWS, + "rows != {} not supported for ScalarZnxDft", + SCALAR_ZNX_DFT_ROWS + ); + debug_assert_eq!( + _size, SCALAR_ZNX_DFT_SIZE, + "rows != {} not supported for ScalarZnxDft", + SCALAR_ZNX_DFT_SIZE + ); + unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } } } -pub trait ScalarZnxDftOps { - /// Allocates a new [SvpPPol]. - fn new_scalar_znx_dft(&self) -> ScalarZnxDft; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. - fn bytes_of_scalar_znx_dft(&self) -> usize; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is owned by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is borrowed by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; - - /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); - - /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of - /// the [VecZnxDft] is multiplied with [SvpPPol]. 
- fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); +impl ZnxLayout for ScalarZnxDft { + type Scalar = f64; } -impl ScalarZnxDftOps for Module { - fn new_scalar_znx_dft(&self) -> ScalarZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_scalar_znx_dft()); - let ptr: *mut u8 = data.as_mut_ptr(); - ScalarZnxDft:: { - data: data, - ptr: ptr, - n: self.n(), - _marker: PhantomData, - } - } - - fn bytes_of_scalar_znx_dft(&self) -> usize { - unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } - } - - fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes(self, bytes) - } - - fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) - } - - fn svp_prepare(&self, res: &mut ScalarZnxDft, a: &Scalar) { - unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) } - } - - fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { - unsafe { - svp::svp_apply_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - a.ptr as *const svp_ppol_t, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } +impl ZnxSliceSize for ScalarZnxDft { + fn sl(&self) -> usize { + self.n() } } diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs new file mode 100644 index 0000000..4fbe99d --- /dev/null +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -0,0 +1,63 @@ +use crate::ffi::svp::{self, svp_ppol_t}; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft}; + +pub trait ScalarZnxDftOps { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft; + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> 
usize; + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, a_col: usize, b: &VecZnx, b_col: usize); +} + +impl ScalarZnxDftOps for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft { + ScalarZnxDft::::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + } + + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { + ScalarZnxDft::::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + } + + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft { + ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + } + + fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + } + + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { + unsafe { + svp::svp_prepare( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t, + a.at_ptr(a_col, 0), + ) + } + } + + fn svp_apply_dft( + &self, + res: &mut VecZnxDft, + res_col: usize, + a: &ScalarZnxDft, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { + unsafe { + svp::svp_apply_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0) as *const svp_ppol_t, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 09ee971..a9dd378 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -26,7 +26,7 @@ impl ZnxAlloc for VecZnxDft { type Scalar = u8; fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, 
size: usize, bytes: &mut [u8]) -> Self { - VecZnxDft { + Self { inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), _marker: PhantomData, } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 7ee1529..6365ad3 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -47,47 +47,47 @@ pub trait VecZnxOps { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ); /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. - fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); + fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. - fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); + fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); /// Subtracts the selected column of `a` to the selected column of `res`. 
- fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); // Negates the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize); + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a`. 
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -95,7 +95,7 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx); /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// @@ -103,7 +103,7 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize); + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize); } impl VecZnxOps for Module { @@ -131,9 +131,9 @@ impl VecZnxOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ) { #[cfg(debug_assertions)] @@ -147,10 +147,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, tmp_bytes.as_mut_ptr(), @@ -158,22 +158,22 @@ impl VecZnxOps for Module { } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_normalize( self, log_base2k, &mut *a_ptr, 
- col_a, + a_col, &*a_ptr, - col_a, + a_col, tmp_bytes, ); } } - fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -184,27 +184,27 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_add( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -215,34 +215,34 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_sub( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, 
&mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } - fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -251,24 +251,24 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_negate( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) { + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -278,24 +278,24 @@ impl VecZnxOps for Module { vec_znx::vec_znx_rotate( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { + fn 
vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -305,24 +305,24 @@ impl VecZnxOps for Module { vec_znx::vec_znx_automorphism( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) { + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) { let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( @@ -339,16 +339,16 @@ impl VecZnxOps for Module { res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, col_res, a, col_a); - self.vec_znx_rotate(-1, buf, 0, a, col_a); + switch_degree(bi, res_col, a, a_col); + self.vec_znx_rotate(-1, buf, 0, a, a_col); } else { - switch_degree(bi, col_res, buf, col_a); - self.vec_znx_rotate_inplace(-1, buf, col_a); + switch_degree(bi, res_col, buf, a_col); + self.vec_znx_rotate_inplace(-1, buf, a_col); } }) } - fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize) { + fn 
vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize) { let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( @@ -364,10 +364,10 @@ impl VecZnxOps for Module { }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(res, col_res, ai, col_a); - self.vec_znx_rotate_inplace(-1, res, col_res); + switch_degree(res, res_col, ai, a_col); + self.vec_znx_rotate_inplace(-1, res, res_col); }); - self.vec_znx_rotate_inplace(a.len() as i64, res, col_res); + self.vec_znx_rotate_inplace(a.len() as i64, res, res_col); } } From 7233e2509d83dda531e3dbb270eb741e73a92145 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 30 Apr 2025 23:23:54 +0200 Subject: [PATCH 21/87] removed unecessary allow --- base2k/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 3fa0bbe..198e197 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -21,10 +21,8 @@ pub use encoding::*; pub use mat_znx_dft::*; pub use module::*; pub use sampling::*; -#[allow(unused_imports)] pub use scalar_znx::*; pub use scalar_znx_dft::*; -#[allow(unused_imports)] pub use scalar_znx_dft_ops::*; pub use stats::*; pub use vec_znx::*; From 4e6fce3458b6b88811b36aabda213a77f0984854 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 1 May 2025 08:39:51 +0200 Subject: [PATCH 22/87] split mat_znx into struct and ops + added missing ops on module --- base2k/src/lib.rs | 2 + base2k/src/mat_znx_dft.rs | 526 +-------------------------------- base2k/src/mat_znx_dft_ops.rs | 536 ++++++++++++++++++++++++++++++++++ 3 files changed, 540 insertions(+), 524 deletions(-) create mode 100644 base2k/src/mat_znx_dft_ops.rs diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 198e197..73d90c2 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,6 +3,7 @@ pub mod encoding; // Other modules and exports pub mod ffi; pub mod mat_znx_dft; +pub mod mat_znx_dft_ops; pub mod module; pub mod sampling; pub mod scalar_znx; @@ -19,6 
+20,7 @@ pub mod znx_base; pub use encoding::*; pub use mat_znx_dft::*; +pub use mat_znx_dft_ops::*; pub use module::*; pub use sampling::*; pub use scalar_znx::*; diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 44d44df..104bd4b 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,8 +1,5 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::ffi::vmp::{self, vmp_pmat_t}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -49,7 +46,7 @@ impl ZnxAlloc for MatZnxDft { } fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } + unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } } } @@ -88,522 +85,3 @@ impl MatZnxDft { } } } - -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; - - /// Allocates a new [MatZnxDft] with the given number of rows and columns. - /// - /// # Arguments - /// - /// * `rows`: number of rows (number of [VecZnxDft]). - /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; - - /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. - /// - /// # Arguments - /// - /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. 
- /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; - - /// Prepares a [MatZnxDft] from a contiguous array of [i64]. - /// The helper struct [Matrix3D] can be used to contruct and populate - /// the appropriate contiguous array. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft]. - /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); - - /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. - /// * `row_i`: the index of the row to prepare. - /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); - - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. - /// * `a`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); - - /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft]. - /// * `row_i`: the index of the row to prepare. - /// - /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. 
- fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); - - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. - /// * `a`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft); - - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. 
- fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. 
- /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); -} - -impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols, size) - } - - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize } - } - - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } - } - - fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.n() * b.poly_count()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_contiguous( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.as_ptr(), - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.size() * self.n() * b.cols()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_row( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.as_ptr(), - row_i as u64, - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); - assert_eq!(a.cols(), b.cols()); - } - unsafe { - vmp::vmp_extract_row( - self.ptr, - b.as_mut_ptr() as *mut vec_znx_big_t, - a.as_ptr() as *const vmp_pmat_t, - row_i as u64, - a.rows() 
as u64, - (a.size() * a.cols()) as u64, - ); - } - } - - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); - } - unsafe { - vmp::vmp_prepare_row_dft( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.as_ptr() as *const vec_znx_dft_t, - row_i as u64, - b.rows() as u64, - b.size() as u64, - ); - } - } - - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); - } - unsafe { - vmp::vmp_extract_row_dft( - self.ptr, - b.as_mut_ptr() as *mut vec_znx_dft_t, - a.as_ptr() as *const vmp_pmat_t, - row_i as u64, - a.rows() as u64, - a.size() as u64, - ); - } - } - - fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { - unsafe { - vmp::vmp_apply_dft_tmp_bytes( - self.ptr, - res_size as u64, - a_size as u64, - b_rows as u64, - b_size as u64, - ) as usize - } - } - - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, - a.as_ptr(), - a.size() as u64, - (a.n() * a.cols()) as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_add( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() 
as u64, - a.as_ptr(), - a.size() as u64, - (a.n() * a.size()) as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { - unsafe { - vmp::vmp_apply_dft_to_dft_tmp_bytes( - self.ptr, - res_size as u64, - a_size as u64, - gct_rows as u64, - gct_size as u64, - ) as usize - } - } - - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.size() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_add( - &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - tmp_bytes: &mut [u8], - ) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft_add( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.size() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - b.as_mut_ptr() as 
*mut vec_znx_dft_t, - b.size() as u64, - b.as_ptr() as *mut vec_znx_dft_t, - b.size() as u64, - a.as_ptr() as *const vmp_pmat_t, - a.rows() as u64, - a.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - alloc_aligned, znx_base::ZnxLayout, - }; - use sampling::source::Source; - - #[test] - fn vmp_prepare_row_dft() { - let module: Module = Module::::new(32); - let vpmat_rows: usize = 4; - let vpmat_size: usize = 5; - let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); - let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); - let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); - - let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); - - for row_i in 0..vpmat_rows { - let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); - module.vec_znx_dft(&mut a_dft, 0, &a, 0); - module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); - - // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); - assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); - - // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0); - assert_eq!(a_dft.raw(), b_dft.raw()); - - // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); - module.vec_znx_idft(&mut 
a_big, 0, &a_dft, 0, &mut tmp_bytes); - assert_eq!(a_big.raw(), b_big.raw()); - } - - module.free(); - } -} diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs new file mode 100644 index 0000000..85177aa --- /dev/null +++ b/base2k/src/mat_znx_dft_ops.rs @@ -0,0 +1,536 @@ +use crate::ffi::vec_znx_big::vec_znx_big_t; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::ffi::vmp; +use crate::znx_base::{ZnxInfos, ZnxLayout}; +use crate::{Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxAlloc, assert_alignement}; + +/// This trait implements methods for vector matrix product, +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. +pub trait MatZnxDftOps { + /// Allocates a new [MatZnxDft] with the given number of rows and columns. + /// + /// # Arguments + /// + /// * `rows`: number of rows (number of [VecZnxDft]). + /// * `size`: number of size (number of size of each [VecZnxDft]). + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; + + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; + + fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec) -> MatZnxDft; + + fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft; + + /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. + /// + /// # Arguments + /// + /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. + /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; + + /// Prepares a [MatZnxDft] from a contiguous array of [i64]. + /// The helper struct [Matrix3D] can be used to contruct and populate + /// the appropriate contiguous array. 
+ /// + /// # Arguments + /// + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft]. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); + + /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. + /// + /// # Arguments + /// + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. + /// * `row_i`: the index of the row to prepare. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. + /// + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. + /// + /// # Arguments + /// + /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. + /// * `row_i`: the index of the row to extract. + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); + + /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft]. + /// * `row_i`: the index of the row to prepare. + /// + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); + + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. + /// * `row_i`: the index of the row to extract. 
+ fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft); + + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. + /// + /// # Arguments + /// + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnx]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. + fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; + + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. + /// + /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. + /// + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. + /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. + /// * `a`: the left operand [VecZnx] of the vector matrix product. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); + + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver. 
+ /// + /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. + /// + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. + /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. + /// * `a`: the left operand [VecZnx] of the vector matrix product. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); + + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. + /// + /// # Arguments + /// + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnxDft]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. + fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; + + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
+ /// + /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. + /// + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. + /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. + /// * `a`: the left operand [VecZnxDft] of the vector matrix product. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); + + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// + /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. + /// + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. 
+ /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. + /// * `a`: the left operand [VecZnxDft] of the vector matrix product. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); + + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// + /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. + /// + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and + /// `j` size, the output is a [VecZnx] of `j` size. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. + /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. + /// * `a`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
+ fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); +} + +impl MatZnxDftOps for Module { + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols, size) + } + + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { + MatZnxDft::::bytes_of(self, rows, cols, size) + } + + fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec) -> MatZnxDft { + MatZnxDft::::from_bytes(self, rows, cols, size, bytes) + } + + fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft { + MatZnxDft::::from_bytes_borrow(self, rows, cols, size, bytes) + } + + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } + } + + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), b.n() * b.poly_count()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_prepare_contiguous( + self.ptr, + b.as_mut_ptr() as *mut vmp::vmp_pmat_t, + a.as_ptr(), + b.rows() as u64, + (b.size() * b.cols()) as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), b.size() * self.n() * b.cols()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_prepare_row( + self.ptr, + b.as_mut_ptr() as *mut vmp::vmp_pmat_t, + a.as_ptr(), + row_i as u64, + b.rows() as u64, + (b.size() * b.cols()) as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vmp_extract_row(&self, b: 
&mut VecZnxBig, a: &MatZnxDft, row_i: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.size(), b.size()); + assert_eq!(a.cols(), b.cols()); + } + unsafe { + vmp::vmp_extract_row( + self.ptr, + b.as_mut_ptr() as *mut vec_znx_big_t, + a.as_ptr() as *const vmp::vmp_pmat_t, + row_i as u64, + a.rows() as u64, + (a.size() * a.cols()) as u64, + ); + } + } + + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.size(), b.size()); + } + unsafe { + vmp::vmp_prepare_row_dft( + self.ptr, + b.as_mut_ptr() as *mut vmp::vmp_pmat_t, + a.as_ptr() as *const vec_znx_dft_t, + row_i as u64, + b.rows() as u64, + b.size() as u64, + ); + } + } + + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.size(), b.size()); + } + unsafe { + vmp::vmp_extract_row_dft( + self.ptr, + b.as_mut_ptr() as *mut vec_znx_dft_t, + a.as_ptr() as *const vmp::vmp_pmat_t, + row_i as u64, + a.rows() as u64, + a.size() as u64, + ); + } + } + + fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { + unsafe { + vmp::vmp_apply_dft_tmp_bytes( + self.ptr, + res_size as u64, + a_size as u64, + b_rows as u64, + b_size as u64, + ) as usize + } + } + + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_apply_dft( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.size() as u64, + a.as_ptr(), + a.size() as u64, + (a.n() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + b.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } + + fn vmp_apply_dft_add(&self, c: 
&mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_apply_dft_add( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.size() as u64, + a.as_ptr(), + a.size() as u64, + (a.n() * a.size()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + b.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } + + fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { + unsafe { + vmp::vmp_apply_dft_to_dft_tmp_bytes( + self.ptr, + res_size as u64, + a_size as u64, + gct_rows as u64, + gct_size as u64, + ) as usize + } + } + + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_apply_dft_to_dft( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.size() as u64, + a.as_ptr() as *const vec_znx_dft_t, + a.size() as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + b.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } + + fn vmp_apply_dft_to_dft_add( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + tmp_bytes: &mut [u8], + ) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_apply_dft_to_dft_add( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.size() as u64, + a.as_ptr() as *const vec_znx_dft_t, + a.size() as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + b.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } 
+ + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size())); + #[cfg(debug_assertions)] + { + assert_alignement(tmp_bytes.as_ptr()); + } + unsafe { + vmp::vmp_apply_dft_to_dft( + self.ptr, + b.as_mut_ptr() as *mut vec_znx_dft_t, + b.size() as u64, + b.as_ptr() as *mut vec_znx_dft_t, + b.size() as u64, + a.as_ptr() as *const vmp::vmp_pmat_t, + a.rows() as u64, + a.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + alloc_aligned, znx_base::ZnxLayout, + }; + use sampling::source::Source; + + #[test] + fn vmp_prepare_row_dft() { + let module: Module = Module::::new(32); + let vpmat_rows: usize = 4; + let vpmat_size: usize = 5; + let log_base2k: usize = 8; + let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); + let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); + let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); + + let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); + + for row_i in 0..vpmat_rows { + let mut source: Source = Source::new([0u8; 32]); + module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); + module.vec_znx_dft(&mut a_dft, 0, &a, 0); + module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); + + // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) + module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); + 
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); + + // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) + module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0); + assert_eq!(a_dft.raw(), b_dft.raw()); + + // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) + module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); + module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes); + assert_eq!(a_big.raw(), b_big.raw()); + } + + module.free(); + } +} From ca5e6d46c9a7a2de2fd57151b2968281f0343b5e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 1 May 2025 10:33:19 +0200 Subject: [PATCH 23/87] Applied discussed changes, everything working, but still to discuss --- base2k/.vscode/settings.json | 3 + base2k/examples/rlwe_encrypt.rs | 19 +- base2k/examples/vector_matrix_product.rs | 59 -- base2k/examples/vmp.rs | 78 +++ base2k/spqlios-arithmetic | 2 +- base2k/src/mat_znx_dft.rs | 52 +- base2k/src/mat_znx_dft_ops.rs | 697 +++++++++++++---------- base2k/src/scalar_znx_dft.rs | 3 +- base2k/src/vec_znx.rs | 37 +- base2k/src/vec_znx_big.rs | 15 +- base2k/src/vec_znx_big_ops.rs | 61 +- base2k/src/vec_znx_dft.rs | 13 +- base2k/src/vec_znx_dft_ops.rs | 74 ++- base2k/src/znx_base.rs | 105 ++-- 14 files changed, 710 insertions(+), 508 deletions(-) delete mode 100644 base2k/examples/vector_matrix_product.rs create mode 100644 base2k/examples/vmp.rs diff --git a/base2k/.vscode/settings.json b/base2k/.vscode/settings.json index eecbcdc..c38916e 100644 --- a/base2k/.vscode/settings.json +++ b/base2k/.vscode/settings.json @@ -4,5 +4,8 @@ "plaintext": false, "markdown": false, "scminput": false + }, + "files.associations": { + "random": "c" } } \ No newline at end of file diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 07fe1c6..2f08633 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -13,7 +13,8 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: 
Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); + let mut tmp_bytes_norm: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); + let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size)); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -38,9 +39,10 @@ fn main() { module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft( + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow( 1, // Number of columns ct.size(), // Number of polynomials per column + &mut tmp_bytes_dft, ); // Applies DFT(ct[1]) * DFT(s) @@ -68,7 +70,7 @@ fn main() { want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - m.normalize(log_base2k, &mut carry); + m.normalize(log_base2k, 0, &mut tmp_bytes_norm); // m - BIG(ct[1] * s) module.vec_znx_big_sub_small_a_inplace( @@ -81,9 +83,12 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) - &buf_big, 0, // Selects the first column of buf_big - &mut carry, + log_base2k, + &mut ct, + 0, // Selects the first column of ct (ct[0]) + &buf_big, + 0, // Selects the first column of buf_big + &mut tmp_bytes_norm, ); // Add noise to ct[0] @@ -120,7 +125,7 @@ fn main() { // m + e <- BIG(ct[1] * s + ct[0]) let mut res: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs deleted file mode 100644 index e565be1..0000000 --- 
a/base2k/examples/vector_matrix_product.rs +++ /dev/null @@ -1,59 +0,0 @@ -use base2k::{ - Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - ZnxInfos, ZnxLayout, alloc_aligned, -}; - -fn main() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let log_base2k: usize = 15; - let limbs_vec: usize = 5; - let log_k: usize = log_base2k * limbs_vec - 5; - - let rows_mat: usize = limbs_vec; - let limbs_mat: usize = limbs_vec + 1; - - // Maximum size of the byte scratch needed - let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat) - | module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat); - - let mut buf: Vec = alloc_aligned(tmp_bytes); - - let mut a_values: Vec = vec![i64::default(); n]; - a_values[1] = (1 << log_base2k) + 1; - - let mut a: VecZnx = module.new_vec_znx(1, limbs_vec); - a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32); - a.normalize(log_base2k, &mut buf); - - a.print(n); - println!(); - - let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); - - (0..a.size()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); - tmp.at_limb_mut(row_i)[1] = 1 as i64; - module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); - }); - - let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); - module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); - - let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); - module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0); - - let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf); - - let mut values_res: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, log_k, &mut values_res); - - res.print(n); - - module.free(); - - println!("{:?}", values_res) -} diff --git a/base2k/examples/vmp.rs 
b/base2k/examples/vmp.rs new file mode 100644 index 0000000..710744e --- /dev/null +++ b/base2k/examples/vmp.rs @@ -0,0 +1,78 @@ +use base2k::{ + Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + ZnxInfos, ZnxLayout, alloc_aligned, +}; + +fn main() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let log_base2k: usize = 15; + + let a_cols: usize = 2; + let a_size: usize = 5; + + let log_k: usize = log_base2k * a_size - 5; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = 2; + let mat_size: usize = a_size + 1; + + let mut tmp_bytes_vmp: Vec = alloc_aligned( + module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + | module.vmp_apply_dft_tmp_bytes( + a_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ), + ); + + let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); + + let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|i| { + let mut values: Vec = vec![i64::default(); n]; + values[1 + i] = (1 << log_base2k) + 1; + a.encode_vec_i64(i, log_base2k, log_k, &values, 32); + a.normalize(log_base2k, i, &mut tmp_bytes_vmp); + a.print(n, i); + println!(); + }); + + let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + (0..a.size()).for_each(|row_i| { + let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); + (0..mat_cols_out).for_each(|j| { + tmp.at_mut(j, row_i)[1 + j] = 1 as i64; + }); + (0..mat_cols_in).for_each(|j| { + module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); + }) + }); + + let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); + module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); + + let mut res: VecZnx = 
module.new_vec_znx(mat_cols_out, a_size); + let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); + + let mut values_res: Vec = vec![i64::default(); n]; + res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); + res.print(n, i); + println!(); + println!("{:?}", values_res); + println!(); + }); + + module.free(); +} diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index e3d3247..8135d85 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit e3d3247335faccf2b6361213c354cd61b958325e +Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 104bd4b..470adcc 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::{Backend, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; @@ -10,6 +10,8 @@ use std::marker::PhantomData; /// See the trait [MatZnxDftOps] for additional information. 
pub struct MatZnxDft { pub inner: ZnxBase, + pub cols_in: usize, + pub cols_out: usize, _marker: PhantomData, } @@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft { type Scalar = f64; } -impl ZnxAlloc for MatZnxDft { - type Scalar = u8; +impl MatZnxDft { + pub fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let bytes: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self::from_bytes(module, rows, cols_in, cols_out, size, bytes) + } - fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec) -> Self { + let mut mat: MatZnxDft = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes); + mat.znx_mut().data = bytes; + mat + } + + pub fn from_bytes_borrow( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: &mut [u8], + ) -> Self { + debug_assert_eq!( + bytes.len(), + Self::bytes_of(module, rows, cols_in, cols_out, size) + ); Self { - inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes), + inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), + cols_in: cols_in, + cols_out: cols_out, _marker: PhantomData, } } - fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize { - unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } + pub fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + crate::ffi::vmp::bytes_of_vmp_pmat( + module.ptr, + (rows * cols_in) as u64, + (size * cols_out) as u64, + ) as usize + } + } + + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out } } diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 85177aa..48c3834 
100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -1,8 +1,9 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxLayout}; -use crate::{Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxAlloc, assert_alignement}; +use crate::{ + Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned, +}; /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [MatZnxDft]. @@ -13,44 +14,45 @@ pub trait MatZnxDftOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft; - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; + fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; - fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec) -> MatZnxDft; + fn new_mat_znx_dft_from_bytes( + &self, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> MatZnxDft; - fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft; + fn new_mat_znx_dft_from_bytes_borrow( + &self, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: &mut [u8], + ) -> MatZnxDft; - /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. - /// - /// # Arguments - /// - /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. - /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. 
- fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; - - /// Prepares a [MatZnxDft] from a contiguous array of [i64]. - /// The helper struct [Matrix3D] can be used to contruct and populate - /// the appropriate contiguous array. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft]. - /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); + /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] + fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. /// /// # Arguments /// /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. - /// * `row_i`: the index of the row to prepare. + /// * `row_i`: the row of the [MatZnxDft] to prepare. + /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft]. /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + fn vmp_prepare_row(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]); + + /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] + fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. /// @@ -59,7 +61,15 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. 
/// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); + fn vmp_extract_row( + &self, + log_base2k: usize, + b: &mut VecZnx, + a: &MatZnxDft, + b_row: usize, + b_col_in: usize, + tmp_bytes: &mut [u8], + ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// @@ -70,7 +80,7 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// @@ -79,7 +89,7 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// @@ -89,7 +99,15 @@ pub trait MatZnxDftOps { /// * `a_size`: number of size of the input [VecZnx]. /// * `rows`: number of rows of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; + fn vmp_apply_dft_tmp_bytes( + &self, + c_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// @@ -117,32 +135,6 @@ pub trait MatZnxDftOps { /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. 
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// /// # Arguments @@ -151,7 +143,17 @@ pub trait MatZnxDftOps { /// * `a_size`: number of size of the input [VecZnxDft]. /// * `rows`: number of rows of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft]. 
- fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + c_cols: usize, + c_size: usize, + a_cols: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -179,308 +181,385 @@ pub trait MatZnxDftOps { /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. 
- /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); } impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols, size) + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols_in, cols_out, size) } - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { - MatZnxDft::::bytes_of(self, rows, cols, size) + fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + MatZnxDft::::bytes_of(self, rows, cols_in, cols_out, size) } - fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec) -> MatZnxDft { - MatZnxDft::::from_bytes(self, rows, cols, size, bytes) + fn new_mat_znx_dft_from_bytes( + &self, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> MatZnxDft { + MatZnxDft::::from_bytes(self, rows, cols_in, cols_out, size, bytes) } - fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft { - MatZnxDft::::from_bytes_borrow(self, rows, cols, size, bytes) + fn new_mat_znx_dft_from_bytes_borrow( + &self, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: &mut [u8], + ) -> MatZnxDft { + MatZnxDft::::from_bytes_borrow(self, rows, cols_in, cols_out, size, bytes) } - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } + fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { + self.bytes_of_vec_znx_dft(cols_out, size) } - fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], tmp_bytes: &mut [u8]) { + fn vmp_prepare_row(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut 
[u8]) { #[cfg(debug_assertions)] { - assert_eq!(a.len(), b.n() * b.poly_count()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); - assert_alignement(tmp_bytes.as_ptr()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + a.cols(), + b.cols_out(), + "a.cols(): {} != b.cols_out(): {}", + a.cols(), + b.cols_out() + ); + assert!( + b_row < b.rows(), + "b_row: {} >= b.rows(): {}", + b_row, + b.rows() + ); + assert!( + b_col_in < b.cols_in(), + "b_col_in: {} >= b.cols_in(): {}", + b_col_in, + b.cols_in() + ); + assert_eq!( + b.size(), + a.size(), + "b.size(): {} != a.size(): {}", + b.size(), + a.size() + ); + assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size())); + assert!(is_aligned(tmp_bytes.as_ptr())) } - unsafe { - vmp::vmp_prepare_contiguous( - self.ptr, - b.as_mut_ptr() as *mut vmp::vmp_pmat_t, - a.as_ptr(), - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), + + let cols_out: usize = a.cols(); + let a_size: usize = a.size(); + + let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); + + let mut a_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft); + (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); + + Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); + } + + fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { + self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes() + } + + fn vmp_extract_row( + &self, + log_base2k: usize, + b: &mut VecZnx, + a: &MatZnxDft, + a_row: usize, + a_col_in: usize, + tmp_bytes: &mut [u8], + ) { + #[cfg(debug_assertions)] + { + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + b.cols(), + a.cols_out(), + "b.cols(): {} != a.cols_out(): {}", + b.cols(), + a.cols_out() + ); + assert!( + a_row < a.rows(), + "a_row: {} >= a.rows(): {}", + 
a_row, + a.rows() + ); + assert!( + a_col_in < a.cols_in(), + "a_col_in: {} >= a.cols_in(): {}", + a_col_in, + a.cols_in() + ); + assert_eq!( + b.size(), + a.size(), + "b.size(): {} != a.size(): {}", + b.size(), + a.size() + ); + assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); + assert!(is_aligned(tmp_bytes.as_ptr())) + } + + let cols_out: usize = b.cols(); + let size: usize = b.size(); + + let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); + let mut b_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft); + Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); + let mut b_big: VecZnxBig = b_dft.alias_as_vec_znx_big(); + (0..cols_out).for_each(|i| { + self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i); + self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes); + }); + } + + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { + #[cfg(debug_assertions)] + { + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + a.cols(), + b.cols_out(), + "a.cols(): {} != b.cols_out(): {}", + a.cols(), + b.cols_out() + ); + assert!( + b_row < b.rows(), + "b_row: {} >= b.rows(): {}", + b_row, + b.rows() + ); + assert!( + b_col_in < b.cols_in(), + "b_col_in: {} >= b.cols_in(): {}", + b_col_in, + b.cols_in() + ); + assert_eq!( + b.size(), + a.size(), + "b.size(): {} != a.size(): {}", + b.size(), + a.size() ); } - } - fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.size() * self.n() * b.cols()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_row( - self.ptr, - b.as_mut_ptr() as *mut vmp::vmp_pmat_t, - a.as_ptr(), - row_i as u64, - b.rows() as u64, - (b.size() * b.cols()) as 
u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); - assert_eq!(a.cols(), b.cols()); - } - unsafe { - vmp::vmp_extract_row( - self.ptr, - b.as_mut_ptr() as *mut vec_znx_big_t, - a.as_ptr() as *const vmp::vmp_pmat_t, - row_i as u64, - a.rows() as u64, - (a.size() * a.cols()) as u64, - ); - } - } - - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); - } unsafe { vmp::vmp_prepare_row_dft( self.ptr, b.as_mut_ptr() as *mut vmp::vmp_pmat_t, a.as_ptr() as *const vec_znx_dft_t, - row_i as u64, - b.rows() as u64, - b.size() as u64, + (b_row * b.cols_in() + b_col_in) as u64, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, ); } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { #[cfg(debug_assertions)] { - assert_eq!(a.n(), b.n()); - assert_eq!(a.size(), b.size()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + b.cols(), + a.cols_out(), + "b.cols(): {} != a.cols_out(): {}", + b.cols(), + a.cols_out() + ); + assert!( + a_row < a.rows(), + "a_row: {} >= a.rows(): {}", + a_row, + a.rows() + ); + assert!( + a_col_in < a.cols_in(), + "a_col_in: {} >= a.cols_in(): {}", + a_col_in, + a.cols_in() + ); + assert_eq!( + b.size(), + a.size(), + "b.size(): {} != a.size(): {}", + b.size(), + a.size() + ); } unsafe { vmp::vmp_extract_row_dft( self.ptr, b.as_mut_ptr() as *mut vec_znx_dft_t, a.as_ptr() as *const vmp::vmp_pmat_t, - row_i as u64, - a.rows() as u64, - a.size() as u64, + (a_row * a.cols_in() + a_col_in) as u64, + (a.rows() * a.cols_in()) as u64, + (a.size() * a.cols_out()) as u64, ); } } - fn 
vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { + fn vmp_apply_dft_tmp_bytes( + &self, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.ptr, res_size as u64, a_size as u64, - b_rows as u64, - b_size as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, ) as usize } } fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); + debug_assert!( + tmp_bytes.len() + >= self.vmp_apply_dft_tmp_bytes( + c.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size() + ) + ); #[cfg(debug_assertions)] { + assert_eq!(c.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + c.cols(), + b.cols_out(), + "c.cols(): {} != b.cols_out: {}", + c.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + assert!( + tmp_bytes.len() + >= self.vmp_apply_dft_tmp_bytes( + c.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size() + ) + ); assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft( self.ptr, c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, + (c.size() * c.cols()) as u64, a.as_ptr(), - a.size() as u64, - (a.n() * a.cols()) as u64, + (a.size() * a.cols()) as u64, + a.n() as u64, b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - b.size() as u64, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - 
assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_add( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, - a.as_ptr(), - a.size() as u64, - (a.n() * a.size()) as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + res_cols: usize, + res_size: usize, + a_size: usize, + a_cols: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.ptr, - res_size as u64, - a_size as u64, - gct_rows as u64, - gct_size as u64, + (res_size * res_cols) as u64, + (a_size * a_cols) as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, ) as usize } } fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { + assert_eq!(c.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + c.cols(), + b.cols_out(), + "c.cols(): {} != b.cols_out: {}", + c.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + assert!( + tmp_bytes.len() + >= self.vmp_apply_dft_to_dft_tmp_bytes( + c.cols(), + c.size(), + a.cols(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size() + ) + ); assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft_to_dft( self.ptr, c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, + c.poly_count() as u64, a.as_ptr() as *const vec_znx_dft_t, - a.size() as u64, + a.poly_count() as u64, b.as_ptr() as *const vmp::vmp_pmat_t, b.rows() as u64, - b.size() as u64, - 
tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_add( - &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - tmp_bytes: &mut [u8], - ) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft_add( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.size() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.size() as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - b.size() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - b.as_mut_ptr() as *mut vec_znx_dft_t, - b.size() as u64, - b.as_ptr() as *mut vec_znx_dft_t, - b.size() as u64, - a.as_ptr() as *const vmp::vmp_pmat_t, - a.rows() as u64, - a.size() as u64, + (b.size() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ) } @@ -497,38 +576,52 @@ mod tests { #[test] fn vmp_prepare_row_dft() { - let module: Module = Module::::new(32); - let vpmat_rows: usize = 4; - let vpmat_size: usize = 5; + let module: Module = Module::::new(16); let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); - let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); - let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); + let mat_rows: usize = 4; + let 
mat_cols_in: usize = 2; + let mat_cols_out: usize = 2; + let mat_size: usize = 5; + let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); + let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut a_big: VecZnxBig = module.new_vec_znx_big(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); + let mut tmp_bytes: Vec = + alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); - for row_i in 0..vpmat_rows { - let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); - module.vec_znx_dft(&mut a_dft, 0, &a, 0); - module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); + for col_in in 0..mat_cols_in { + for row_i in 0..mat_rows { + let mut source: Source = Source::new([0u8; 32]); - // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); - assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); + (0..mat_cols_out).for_each(|col_out| { + module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source); + module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); + }); - // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0); - assert_eq!(a_dft.raw(), b_dft.raw()); + module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes); - // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row(&mut b_big, 
&vmpmat_0, row_i); - module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes); - assert_eq!(a_big.raw(), b_big.raw()); + // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) + module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); + assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); + + // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) + module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in); + assert_eq!(a_dft.raw(), b_dft.raw()); + + // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) + module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes); + + (0..mat_cols_out).for_each(|col_out| { + module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); + module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes); + }); + + assert_eq!(a.raw(), b.raw()); + } } module.free(); diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index ffb54b5..6fdb991 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -28,6 +28,7 @@ impl ZnxAlloc for ScalarZnxDft { type Scalar = u8; fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { + debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); Self { inner: ZnxBase::from_bytes_borrow( module.n(), @@ -61,6 +62,6 @@ impl ZnxLayout for ScalarZnxDft { impl ZnxSliceSize for ScalarZnxDft { fn sl(&self) -> usize { - self.n() + self.n() * self.cols() } } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 125f32e..544c096 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -3,7 +3,7 @@ use crate::Module; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, 
ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree}; use std::cmp::min; pub const VEC_ZNX_ROWS: usize = 1; @@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx { type Scalar = i64; } -impl ZnxBasics for VecZnx {} +impl ZnxZero for VecZnx {} + +impl ZnxRsh for VecZnx {} impl ZnxAlloc for VecZnx { type Scalar = i64; @@ -84,7 +86,7 @@ impl VecZnx { /// /// * `log_base2k`: the base two logarithm of the coefficients decomposition. /// * `k`: the number of bits of precision to drop. - pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) { + pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) { if k == 0 { return; } @@ -101,7 +103,7 @@ impl VecZnx { if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_limb_mut(self.size() - 1) + self.at_mut(col, self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -111,8 +113,8 @@ impl VecZnx { copy_vec_znx_from(self, a); } - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { - normalize(log_base2k, self, carry) + pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { + normalize(log_base2k, self, col, carry) } pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { @@ -120,26 +122,25 @@ impl VecZnx { } // Prints the first `n` coefficients of each limb - pub fn print(&self, n: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) + pub fn print(&self, n: usize, col: usize) { + (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); } } -fn normalize_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +fn normalize_tmp_bytes(n: usize) -> usize { + n * std::mem::size_of::() } -fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { +fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let cols: usize = a.cols(); debug_assert!( - tmp_bytes.len() >= 
normalize_tmp_bytes(n, cols), - "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})", + tmp_bytes.len() >= normalize_tmp_bytes(n), + "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})", tmp_bytes.len(), n, - cols, ); + #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()) @@ -151,11 +152,11 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); (0..a.size()).rev().for_each(|i| { znx::znx_normalize( - (n * cols) as u64, + n as u64, log_base2k as u64, - a.at_mut_ptr(0, i), + a.at_mut_ptr(a_col, i), carry_i64.as_mut_ptr(), - a.at_mut_ptr(0, i), + a.at_mut_ptr(a_col, i), carry_i64.as_mut_ptr(), ) }); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index cbcd4b9..5ba7dde 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; use crate::{Backend, FFT64, Module, NTT120}; use std::marker::PhantomData; @@ -26,6 +26,7 @@ impl ZnxAlloc for VecZnxBig { type Scalar = u8; fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); VecZnxBig { inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), _marker: PhantomData, @@ -50,24 +51,24 @@ impl ZnxLayout for VecZnxBig { type Scalar = i128; } -impl ZnxBasics for VecZnxBig {} +impl ZnxZero for VecZnxBig {} impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { - self.n() + self.n() * self.cols() } } impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { - self.n() * 4 + self.n() * 4 * self.cols() } } -impl ZnxBasics for VecZnxBig {} +impl ZnxZero for VecZnxBig {} impl VecZnxBig { - pub 
fn print(&self, n: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + pub fn print(&self, n: usize, col: usize) { + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); } } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 9c6feee..8be526e 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,4 +1,4 @@ -use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; +use crate::ffi::vec_znx; use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; @@ -171,14 +171,17 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_add( + vec_znx::vec_znx_add( self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + res.at_mut_ptr(res_col, 0), res.size() as u64, - a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, + res.sl() as u64, + a.at_ptr(a_col, 0), a.size() as u64, - b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, + a.sl() as u64, + b.at_ptr(b_col, 0), b.size() as u64, + b.sl() as u64, ) } } @@ -207,14 +210,17 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_sub( + vec_znx::vec_znx_sub( self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + res.at_mut_ptr(res_col, 0), res.size() as u64, - a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, + res.sl() as u64, + a.at_ptr(a_col, 0), a.size() as u64, - b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, + a.sl() as u64, + b.at_ptr(b_col, 0), b.size() as u64, + b.sl() as u64, ) } } @@ -250,12 +256,14 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_sub_small_b( + vec_znx::vec_znx_sub( self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + 
res.at_mut_ptr(res_col, 0), res.size() as u64, - a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, + res.sl() as u64, + a.at_ptr(a_col, 0), a.size() as u64, + a.sl() as u64, b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, @@ -287,15 +295,17 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_sub_small_a( + vec_znx::vec_znx_sub( self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + res.at_mut_ptr(res_col, 0), res.size() as u64, + res.sl() as u64, a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t, + b.at_ptr(b_col, 0), b.size() as u64, + b.sl() as u64, ) } } @@ -324,12 +334,14 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_add_small( + vec_znx::vec_znx_add( self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + res.at_mut_ptr(res_col, 0), res.size() as u64, - a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, + res.sl() as u64, + a.at_ptr(a_col, 0), a.size() as u64, + a.sl() as u64, b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, @@ -365,14 +377,15 @@ impl VecZnxBigOps for Module { assert_alignement(tmp_bytes.as_ptr()); } unsafe { - vec_znx_big::vec_znx_big_normalize_base2k( + vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, + a.at_ptr(a_col, 0), a.size() as u64, + a.sl() as u64, tmp_bytes.as_mut_ptr(), ); } @@ -385,13 +398,15 @@ impl VecZnxBigOps for Module { assert_eq!(res.n(), self.n()); } unsafe { - vec_znx_big::vec_znx_big_automorphism( + vec_znx::vec_znx_automorphism( self.ptr, k, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, + res.at_mut_ptr(res_col, 0), res.size() as u64, - a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, + res.sl() as u64, + a.at_ptr(a_col, 0), 
a.size() as u64, + a.sl() as u64, ) } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index a9dd378..b187645 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_dft; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; use crate::{Backend, FFT64, Module, VecZnxBig}; use std::marker::PhantomData; @@ -26,6 +26,7 @@ impl ZnxAlloc for VecZnxDft { type Scalar = u8; fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); Self { inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), _marker: PhantomData, @@ -46,6 +47,8 @@ impl ZnxLayout for VecZnxDft { type Scalar = f64; } +impl ZnxZero for VecZnxDft {} + impl ZnxSliceSize for VecZnxDft { fn sl(&self) -> usize { self.n() @@ -53,8 +56,8 @@ impl ZnxSliceSize for VecZnxDft { } impl VecZnxDft { - pub fn print(&self, n: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + pub fn print(&self, n: usize, col: usize) { + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); } } @@ -63,6 +66,10 @@ impl VecZnxDft { /// The returned [VecZnxBig] shares the backing array /// with the original [VecZnxDft]. 
pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { + assert!( + self.data().len() == 0, + "cannot alias VecZnxDft into VecZnxBig if it owns the data" + ); VecZnxBig:: { inner: ZnxBase { data: Vec::new(), diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 57b3777..679abce 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -4,7 +4,8 @@ use crate::znx_base::ZnxAlloc; use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxLayout; use crate::znx_base::ZnxSliceSize; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement}; +use std::cmp::min; pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. @@ -77,19 +78,21 @@ impl VecZnxDftOps for Module { } fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_col: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(res.poly_count(), a.poly_count()); - } + let min_size: usize = min(res.size(), a.size()); unsafe { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, - res.size() as u64, - a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, - a.size() as u64, - ) + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + (min_size..res.size()).for_each(|j| { + res.zero_at(res_col, j); + }) } } @@ -102,15 +105,22 @@ impl VecZnxDftOps for Module { /// # Panics /// If b.cols < a_cols fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { + let min_size: usize = min(res.size(), a.size()); + unsafe { - vec_znx_dft::vec_znx_dft( - self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) 
as *mut vec_znx_dft::vec_znx_dft_t, - res.size() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_dft( + self.ptr, + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + a.at_ptr(a_col, j), + 1 as u64, + a.sl() as u64, + ) + }); + (min_size..res.size()).for_each(|j| { + res.zero_at(res_col, j); + }); } } @@ -126,15 +136,23 @@ impl VecZnxDftOps for Module { ); assert_alignement(tmp_bytes.as_ptr()) } + + let min_size: usize = min(res.size(), a.size()); + unsafe { - vec_znx_dft::vec_znx_idft( - self.ptr, - res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, - res.size() as u64, - a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t, - a.size() as u64, - tmp_bytes.as_mut_ptr(), - ) + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_idft( + self.ptr, + res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1 as u64, + tmp_bytes.as_mut_ptr(), + ) + }); + (min_size..res.size()).for_each(|j| { + res.zero_at(res_col, j); + }); } } } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 64ad85f..4cacb70 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -22,6 +22,33 @@ pub struct ZnxBase { pub ptr: *mut u8, } +impl ZnxBase { + pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { + let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); + res.data = bytes; + res + } + + pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert_eq!(n & (n - 1), 0, "n must be a power of two"); + assert!(n > 0, "n must be greater than 0"); + assert!(rows > 0, "rows must be greater than 0"); + assert!(cols > 0, "cols must be greater than 0"); + assert!(size > 0, "size must be greater than 0"); + } + Self 
{ + n: n, + rows: rows, + cols: cols, + size: size, + data: Vec::new(), + ptr: bytes.as_mut_ptr(), + } + } +} + pub trait GetZnxBase { fn znx(&self) -> &ZnxBase; fn znx_mut(&mut self) -> &mut ZnxBase; @@ -52,10 +79,12 @@ pub trait ZnxInfos: GetZnxBase { self.znx().size } + /// Returns the underlying raw bytes array. fn data(&self) -> &[u8] { &self.znx().data } + /// Returns a pointer to the underlying raw bytes array. fn ptr(&self) -> *mut u8 { self.znx().ptr } @@ -72,33 +101,6 @@ pub trait ZnxSliceSize { fn sl(&self) -> usize; } -impl ZnxBase { - pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { - let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); - res.data = bytes; - res - } - - pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_eq!(n & (n - 1), 0, "n must be a power of two"); - assert!(n > 0, "n must be greater than 0"); - assert!(rows > 0, "rows must be greater than 0"); - assert!(cols > 0, "cols must be greater than 0"); - assert!(size > 0, "size must be greater than 0"); - } - Self { - n: n, - rows: rows, - cols: cols, - size: size, - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - } - } -} - pub trait ZnxAlloc where Self: Sized + ZnxInfos, @@ -148,25 +150,25 @@ pub trait ZnxLayout: ZnxInfos { unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } } - /// Returns a non-mutable pointer starting at the (i, j)-th small polynomial. + /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. 
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { assert!(i < self.cols()); assert!(j < self.size()); } - let offset = self.n() * (j * self.cols() + i); + let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } } - /// Returns a mutable pointer starting at the (i, j)-th small polynomial. + /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { assert!(i < self.cols()); assert!(j < self.size()); } - let offset = self.n() * (j * self.cols() + i); + let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } } @@ -179,16 +181,6 @@ pub trait ZnxLayout: ZnxInfos { fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } - - /// Returns non-mutable reference to the i-th limb. - fn at_limb(&self, j: usize) -> &[Self::Scalar] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) } - } - - /// Returns mutable reference to the i-th limb. 
- fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } - } } use std::convert::TryFrom; @@ -221,14 +213,17 @@ impl IntegerType for i128 { const BITS: u32 = 128; } -pub trait ZnxBasics: ZnxLayout +pub trait ZnxZero: ZnxLayout where Self: Sized, - Self::Scalar: IntegerType, { fn zero(&mut self) { unsafe { - std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::()); + std::ptr::write_bytes( + self.as_mut_ptr(), + 0, + self.n() * size_of::() * self.poly_count(), + ); } } @@ -241,13 +236,19 @@ where ); } } +} - fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) +pub trait ZnxRsh: ZnxLayout + ZnxZero +where + Self: Sized, + Self::Scalar: IntegerType, +{ + fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { + rsh(k, log_base2k, self, col, carry) } } -pub fn rsh(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8]) +pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) where V::Scalar: IntegerType, { @@ -258,7 +259,7 @@ where #[cfg(debug_assertions)] { assert!( - tmp_bytes.len() >= rsh_tmp_bytes::(n, cols), + tmp_bytes.len() >= rsh_tmp_bytes::(n), "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", tmp_bytes.len() / size_of::(), n, @@ -291,7 +292,7 @@ where let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); (steps..size).for_each(|i| { - izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { + izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k_t; *ci = get_base_k_carry(*xi, shift); *xi = (*xi - *ci) >> k_rem_t; @@ -305,11 +306,11 @@ fn get_base_k_carry(x: T, shift: T) -> T { (x << shift) >> shift } -pub fn rsh_tmp_bytes(n: usize, cols: usize) -> usize { - n * cols * std::mem::size_of::() +pub fn rsh_tmp_bytes(n: usize) -> usize { + n * 
std::mem::size_of::() } -pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) +pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) where ::Scalar: IntegerType, { From 3ed6fa8ab576fe246ced046bc342d4ae141e3b39 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 2 May 2025 20:49:04 +0530 Subject: [PATCH 24/87] wip --- base2k/spqlios-arithmetic | 2 +- base2k/src/vec_znx.rs | 223 ++++++++++++++++++++++---------- base2k/src/vec_znx_big.rs | 117 +++++++++-------- base2k/src/vec_znx_big_ops.rs | 230 ++++++++++++++++++++-------------- base2k/src/vec_znx_dft.rs | 172 ++++++++++++++++--------- base2k/src/vec_znx_dft_ops.rs | 95 +++++++++----- base2k/src/vec_znx_ops.rs | 209 ++++++++++++++++++------------ base2k/src/znx_base.rs | 165 +++++++++++++++--------- 8 files changed, 770 insertions(+), 443 deletions(-) diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 8135d85..e3d3247 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf +Subproject commit e3d3247335faccf2b6361213c354cd61b958325e diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 544c096..b76f93d 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,12 +1,16 @@ use crate::Backend; +use crate::DataView; +use crate::DataViewMut; use crate::Module; +use crate::ZnxView; +use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree}; -use std::cmp::min; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree}; +use std::{cmp::min, fmt}; -pub const VEC_ZNX_ROWS: usize = 1; +// pub const VEC_ZNX_ROWS: usize = 1; /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. 
@@ -18,68 +22,57 @@ pub const VEC_ZNX_ROWS: usize = 1; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. -pub struct VecZnx { - pub inner: ZnxBase, +pub struct VecZnx { + data: D, + n: usize, + cols: usize, + size: usize, } -impl GetZnxBase for VecZnx { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnx { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 } -} -impl ZnxInfos for VecZnx {} + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } -impl ZnxSliceSize for VecZnx { fn sl(&self) -> usize { self.cols() * self.n() } } -impl ZnxLayout for VecZnx { - type Scalar = i64; -} - -impl ZnxZero for VecZnx {} - -impl ZnxRsh for VecZnx {} - -impl ZnxAlloc for VecZnx { - type Scalar = i64; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - VecZnx { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_ROWS, - "rows != {} not supported for VecZnx", - VEC_ZNX_ROWS - ); - module.n() * cols * size * size_of::() +impl DataView for VecZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -/// Copies the coefficients of `a` on the receiver. -/// Copy is done with the minimum size matching both backing arrays. -/// Panics if the cols do not match. 
-pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { - assert_eq!(b.cols(), a.cols()); - let data_a: &[i64] = a.raw(); - let data_b: &mut [i64] = b.raw_mut(); - let size = min(data_b.len(), data_a.len()); - data_b[..size].copy_from_slice(&data_a[..size]) +impl DataViewMut for VecZnx { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } } -impl VecZnx { +impl> ZnxView for VecZnx { + type Scalar = i64; +} + +impl + AsRef<[u8]>> VecZnx { + pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { + normalize(log_base2k, self, col, carry) + } + /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -91,12 +84,6 @@ impl VecZnx { return; } - if !self.borrowing() { - self.inner - .data - .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); - } - self.inner.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -109,29 +96,72 @@ impl VecZnx { } } - pub fn copy_from(&mut self, a: &Self) { - copy_vec_znx_from(self, a); - } - - pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { - normalize(log_base2k, self, col, carry) - } - - pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { - switch_degree(a, col_a, self, col) + /// Switches degree of from `a.n()` to `self.n()` into `self` + pub fn switch_degree>(&mut self, col: usize, a: &Data, col_a: usize) { + switch_degree(self, col_a, a, col) } // Prints the first `n` coefficients of each limb - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); + // pub fn print(&self, n: usize, col: usize) { + // (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); + // } +} + +impl>> VecZnx { + pub(crate) fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() + } + + pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); + 
Self { + data: data.into(), + n, + cols, + size, + } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + } } } +//(Jay)TODO: Impl. truncate pow2 for Owned Vector + +/// Copies the coefficients of `a` on the receiver. +/// Copy is done with the minimum size matching both backing arrays. +/// Panics if the cols do not match. +pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + assert_eq!(b.cols(), a.cols()); + let data_a: &[i64] = a.raw(); + let data_b: &mut [i64] = b.raw_mut(); + let size = min(data_b.len(), data_a.len()); + data_b[..size].copy_from_slice(&data_a[..size]) +} + +// if !self.borrowing() { +// self.inner +// .data +// .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); +// } + fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { +fn normalize>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( @@ -162,3 +192,62 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u }); } } + +// impl ZnxAlloc for VecZnx { +// type Scalar = i64; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { +// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); +// VecZnx { +// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { +// debug_assert_eq!( +// _rows, VEC_ZNX_ROWS, +// "rows != {} not supported for VecZnx", +// VEC_ZNX_ROWS +// ); +// module.n() * cols * size * size_of::() +// } +// } + +impl> 
fmt::Display for VecZnx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} + +pub type VecZnxOwned = VecZnx>; +pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; +pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 5ba7dde..682493a 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,74 +1,91 @@ use crate::ffi::vec_znx_big; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; -use crate::{Backend, FFT64, Module, NTT120}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxView}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; const VEC_ZNX_BIG_ROWS: usize = 1; -pub struct VecZnxBig { - pub inner: ZnxBase, - pub _marker: PhantomData, +/// VecZnxBig is Backend dependent, denoted with backend generic `B` +pub struct VecZnxBig { + data: D, + n: usize, + cols: usize, + size: usize, + _phantom: PhantomData, } -impl GetZnxBase for VecZnxBig { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnxBig { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } + + fn sl(&self) -> usize { + 
self.cols() * self.n() } } -impl ZnxInfos for VecZnxBig {} - -impl ZnxAlloc for VecZnxBig { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - VecZnxBig { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_BIG_ROWS, - "rows != {} not supported for VecZnxBig", - VEC_ZNX_BIG_ROWS - ); - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } +impl DataView for VecZnxBig { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for VecZnxBig { +impl DataViewMut for VecZnxBig { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxBig { type Scalar = i64; } -impl ZnxLayout for VecZnxBig { - type Scalar = i128; -} +impl>, B: Backend> VecZnxBig { + pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } + } -impl ZnxZero for VecZnxBig {} + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } + } -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - self.n() * self.cols() + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } } } -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - self.n() * 4 * self.cols() - } -} +pub 
type VecZnxBigOwned = VecZnxBig, B>; -impl ZnxZero for VecZnxBig {} - -impl VecZnxBig { - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); - } -} +// impl VecZnxBig { +// pub fn print(&self, n: usize, col: usize) { +// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); +// } +// } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 8be526e..5353c32 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,10 +1,10 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, DataView, FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; -pub trait VecZnxBigOps { +pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -18,98 +18,100 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials.. - /// * `size`: the number of polynomials per column. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. 
- /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + // /// Returns a new [VecZnxBig] with the provided bytes array as backing array. + // /// + // /// Behavior: the backing array is only borrowed. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of polynomials.. + // /// * `size`: the number of polynomials per column. + // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. + // /// + // /// # Panics + // /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. + // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxBigOps { /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. 
- fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `a` to `b` and stores the result on `c`. fn vec_znx_big_sub( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ); /// Subtracts `b` to `a` and stores the result on `b`. 
- fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; @@ -123,44 +125,57 @@ pub trait VecZnxBigOps { fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, tmp_bytes: &mut [u8], ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism( + &self, + k: i64, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); } -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { - VecZnxBig::new(self, 1, cols, size) +impl VecZnxBigAlloc for Module { + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::new(self, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig { - VecZnxBig::from_bytes(self, 1, cols, size, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::new_from_bytes(self, cols, size, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) - } + // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: 
&mut [u8]) -> VecZnxBig { + // VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) + // } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, 1, cols, size) + VecZnxBig::bytes_of(self, cols, size) } +} +impl VecZnxBigOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ fn vec_znx_big_add( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -186,20 +201,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn vec_znx_big_add_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_add(self, res, res_col, a, a_col, res, res_col); } } fn vec_znx_big_sub( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -225,27 +245,38 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + //(Jay)TODO: check whether definitions sub_ab, sub_ba make sense to you + fn vec_znx_big_sub_ab_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_sub(self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn 
vec_znx_big_sub_ba_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + Self::vec_znx_big_sub(self, res, res_col, res, res_col, a, a_col); } } fn vec_znx_big_sub_small_b( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ) { #[cfg(debug_assertions)] @@ -271,20 +302,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_sub_small_b_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnx, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + Self::vec_znx_big_sub_small_b(self, res, res_col, res, res_col, a, a_col); } } fn vec_znx_big_sub_small_a( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -310,20 +346,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_sub_small_a_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnx, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_sub_small_a(self, res, res_col, a, a_col, res, res_col); } } fn vec_znx_big_add_small( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, 
b_col: usize, ) { #[cfg(debug_assertions)] @@ -349,11 +390,8 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { - unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); - } + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + Self::vec_znx_big_add_small(self, res, res_col, res, res_col, a, a_col); } fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { @@ -363,9 +401,9 @@ impl VecZnxBigOps for Module { fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, tmp_bytes: &mut [u8], ) { @@ -391,7 +429,14 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism( + &self, + k: i64, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -411,10 +456,9 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { unsafe { - let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; - Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_big_automorphism(self, k, a, a_col, a, a_col); } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b187645..c192486 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,85 +1,135 @@ -use crate::ffi::vec_znx_dft; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; -use crate::{Backend, FFT64, Module, VecZnxBig}; use 
std::marker::PhantomData; +use crate::ffi::vec_znx_dft; +use crate::znx_base::{ZnxAlloc, ZnxInfos}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; + const VEC_ZNX_DFT_ROWS: usize = 1; -pub struct VecZnxDft { - inner: ZnxBase, - pub _marker: PhantomData, +pub struct VecZnxDft { + data: D, + n: usize, + cols: usize, + size: usize, + _phantom: PhantomData, } -impl GetZnxBase for VecZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnxDft { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } + + fn sl(&self) -> usize { + self.cols() * self.n() } } -impl ZnxInfos for VecZnxDft {} - -impl ZnxAlloc for VecZnxDft { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_DFT_ROWS, - "rows != {} not supported for VecZnxDft", - VEC_ZNX_DFT_ROWS - ); - unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +impl DataView for VecZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for VecZnxDft { +impl DataViewMut for VecZnxDft { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxDft { type Scalar = f64; } -impl ZnxZero for VecZnxDft {} - -impl ZnxSliceSize for VecZnxDft { - fn sl(&self) -> usize { - self.n() +impl>, B: Backend> VecZnxDft { + pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { 
vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } -} -impl VecZnxDft { - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } } -} -impl VecZnxDft { - /// Cast a [VecZnxDft] into a [VecZnxBig]. - /// The returned [VecZnxBig] shares the backing array - /// with the original [VecZnxDft]. - pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { - assert!( - self.data().len() == 0, - "cannot alias VecZnxDft into VecZnxBig if it owns the data" - ); - VecZnxBig:: { - inner: ZnxBase { - data: Vec::new(), - ptr: self.ptr(), - n: self.n(), - rows: self.rows(), - cols: self.cols(), - size: self.size(), - }, - _marker: PhantomData, + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, } } } + +pub type VecZnxDftOwned = VecZnxDft, B>; + +// impl ZnxAlloc for VecZnxDft { +// type Scalar = u8; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { +// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); +// Self { +// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), +// _marker: PhantomData, +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { +// debug_assert_eq!( +// _rows, VEC_ZNX_DFT_ROWS, +// "rows != {} not supported for VecZnxDft", +// VEC_ZNX_DFT_ROWS +// ); +// unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +// } +// } + +// 
impl VecZnxDft { +// pub fn print(&self, n: usize, col: usize) { +// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); +// } +// } + +// impl VecZnxDft { +// /// Cast a [VecZnxDft] into a [VecZnxBig]. +// /// The returned [VecZnxBig] shares the backing array +// /// with the original [VecZnxDft]. +// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { +// assert!( +// self.data().len() == 0, +// "cannot alias VecZnxDft into VecZnxBig if it owns the data" +// ); +// VecZnxBig:: { +// inner: ZnxBase { +// data: Vec::new(), +// ptr: self.ptr(), +// n: self.n(), +// rows: self.rows(), +// cols: self.cols(), +// size: self.size(), +// }, +// _marker: PhantomData, +// } +// } +// } diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 679abce..cf2090b 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,15 +1,14 @@ +use crate::VecZnxDftOwned; use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxAlloc; use crate::znx_base::ZnxInfos; -use crate::znx_base::ZnxLayout; -use crate::znx_base::ZnxSliceSize; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement}; +use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; use std::cmp::min; -pub trait VecZnxDftOps { +pub trait VecZnxDftAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -22,20 +21,20 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. 
- fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; + // /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + // /// + // /// Behavior: the backing array is only borrowed. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of cols of the [VecZnxDft]. + // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + // /// + // /// # Panics + // /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -47,37 +46,58 @@ pub trait VecZnxDftOps { /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxDftOps { /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. 
- fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); - fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_idft( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxDft, + a_col: usize, + tmp_bytes: &mut [u8], + ); - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); } -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { - VecZnxDft::::new(&self, 1, cols, size) +impl VecZnxDftAlloc for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::new(&self, cols, size) } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft { - VecZnxDft::from_bytes(self, 1, cols, size, bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) - } + // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { + // VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) + // } fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, 1, cols, size) + VecZnxDft::bytes_of(&self, cols, size) } +} - fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_col: usize) { +impl VecZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + fn vec_znx_idft_tmp_a( + &self, + res: &mut VecZnxBig, + res_col: 
usize, + a: &mut VecZnxDft, + a_col: usize, + ) { let min_size: usize = min(res.size(), a.size()); unsafe { @@ -86,7 +106,7 @@ impl VecZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + a.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, ) }); @@ -104,7 +124,7 @@ impl VecZnxDftOps for Module { /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { let min_size: usize = min(res.size(), a.size()); unsafe { @@ -125,7 +145,14 @@ impl VecZnxDftOps for Module { } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_idft( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxDft, + a_col: usize, + tmp_bytes: &mut [u8], + ) { #[cfg(debug_assertions)] { assert!( diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 6365ad3..339bc12 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,14 +1,15 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; -use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement}; -pub trait VecZnxOps { +use crate::znx_base::{ZnxInfos, switch_degree}; +use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement}; + +pub trait VecZnxAlloc { /// Allocates a new [VecZnx]. /// /// # Arguments /// /// * `cols`: the number of polynomials. /// * `size`: the number small polynomials per column. - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned; /// Instantiates a new [VecZnx] from a slice of bytes. 
/// The returned [VecZnx] takes ownership of the slice of bytes. @@ -20,25 +21,28 @@ pub trait VecZnxOps { /// /// # Panic /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; - /// Instantiates a new [VecZnx] from a slice of bytes. - /// The returned [VecZnx] does take ownership of the slice of bytes. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number small polynomials per column. - /// - /// # Panic - /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + // /// Instantiates a new [VecZnx] from a slice of bytes. + // /// The returned [VecZnx] does take ownership of the slice of bytes. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of polynomials. + // /// * `size`: the number small polynomials per column. + // /// + // /// # Panic + // /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. + // fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + // (Jay)TODO /// Returns the number of bytes necessary to allocate /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxOps { /// Returns the minimum number of bytes necessary for normalization. fn vec_znx_normalize_tmp_bytes(&self) -> usize; @@ -46,48 +50,64 @@ pub trait VecZnxOps { fn vec_znx_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, tmp_bytes: &mut [u8], ); /// Normalizes the selected column of `a`. 
- fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); - /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. - fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); + /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + fn vec_znx_add( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); - /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. - fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); + /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. + fn vec_znx_sub( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); - /// Subtracts the selected column of `a` to the selected column of `res`. - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Subtracts the selected column of `a` from the selected column of `res` inplace. 
+ fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + // /// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`. + // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - // Negates the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + // Negates the selected column of `a` and stores the result in `res_col` of `res`. + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); - /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); - /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. 
- fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`. + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -95,7 +115,14 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx); + fn vec_znx_split( + &self, + res: &mut Vec>, + res_col: usize, + a: &VecZnx, + a_col: usize, + buf: &mut VecZnx, + ); /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// @@ -103,26 +130,29 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize); + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize); } -impl VecZnxOps for Module { - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { - VecZnx::new(self, VEC_ZNX_ROWS, cols, size) +impl VecZnxAlloc for Module { + //(Jay)TODO: One must define the Scalar generic param here. 
+ fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { + VecZnxOwned::new(self.n(), cols, size) } fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size) + VecZnxOwned::bytes_of(self.n(), cols, size) } - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx { - VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes) - } - - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes) + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { + VecZnxOwned::new_from_bytes(self.n(), cols, size, bytes) } +} +impl VecZnxOps for Module +where + Data: AsRef<[u8]>, + DataMut: AsRef<[u8]> + AsMut<[u8]>, +{ fn vec_znx_normalize_tmp_bytes(&self) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } } @@ -130,9 +160,9 @@ impl VecZnxOps for Module { fn vec_znx_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, tmp_bytes: &mut [u8], ) { @@ -158,7 +188,7 @@ impl VecZnxOps for Module { } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_normalize( @@ -173,7 +203,15 @@ impl VecZnxOps for Module { } } - fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { + fn vec_znx_add( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -197,14 +235,21 @@ impl VecZnxOps for Module { } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, 
res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { - let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_add(&self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { + fn vec_znx_sub( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -228,21 +273,21 @@ impl VecZnxOps for Module { } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_sub(self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { - unsafe { - let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); - } - } + // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + // unsafe { + // let res_ptr: *mut VecZnx = res as *mut VecZnx; + // Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + // } + // } - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -261,14 +306,13 @@ impl VecZnxOps for Module { } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { + fn vec_znx_negate_inplace(&self, 
a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_negate(self, a, a_col, a, a_col); } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -288,14 +332,13 @@ impl VecZnxOps for Module { } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_rotate(self, k, a, a_col, a, a_col); } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -315,14 +358,20 @@ impl VecZnxOps for Module { } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_automorphism(self, k, a, a_col, a, a_col); } } - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) { + fn vec_znx_split( + &self, + res: &mut Vec>, + res_col: usize, + a: &VecZnx, + a_col: usize, + buf: &mut VecZnx, + ) { let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( @@ -348,7 +397,7 @@ impl VecZnxOps for Module { }) } - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize) { + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, 
a_col: usize) { let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 4cacb70..bf941d4 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -54,11 +54,9 @@ pub trait GetZnxBase { fn znx_mut(&mut self) -> &mut ZnxBase; } -pub trait ZnxInfos: GetZnxBase { +pub trait ZnxInfos { /// Returns the ring degree of the polynomials. - fn n(&self) -> usize { - self.znx().n - } + fn n(&self) -> usize; /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { @@ -66,41 +64,27 @@ pub trait ZnxInfos: GetZnxBase { } /// Returns the number of rows. - fn rows(&self) -> usize { - self.znx().rows - } + fn rows(&self) -> usize; + /// Returns the number of polynomials in each row. - fn cols(&self) -> usize { - self.znx().cols - } + fn cols(&self) -> usize; /// Returns the number of size per polynomial. - fn size(&self) -> usize { - self.znx().size - } - - /// Returns the underlying raw bytes array. - fn data(&self) -> &[u8] { - &self.znx().data - } - - /// Returns a pointer to the underlying raw bytes array. - fn ptr(&self) -> *mut u8 { - self.znx().ptr - } + fn size(&self) -> usize; /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } -} -pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. fn sl(&self) -> usize; } +// pub trait ZnxSliceSize {} + +//(Jay) TODO: Remove ZnxAlloc pub trait ZnxAlloc where Self: Sized + ZnxInfos, @@ -122,22 +106,21 @@ where fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; } -pub trait ZnxLayout: ZnxInfos { - type Scalar; +pub trait DataView { + type D; + fn data(&self) -> &Self::D; +} - /// Returns true if the receiver is only borrowing the data. 
- fn borrowing(&self) -> bool { - self.znx().data.len() == 0 - } +pub trait DataViewMut: DataView { + fn data_mut(&self) -> &mut Self::D; +} + +pub trait ZnxView: ZnxInfos + DataView> { + type Scalar; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { - self.znx().ptr as *const Self::Scalar - } - - /// Returns a mutable pointer to the underlying coefficients array. - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.znx_mut().ptr as *mut Self::Scalar + self.data().as_ref().as_ptr() as *const Self::Scalar } /// Returns a non-mutable reference to the entire underlying coefficient array. @@ -145,11 +128,6 @@ pub trait ZnxLayout: ZnxInfos { unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } - /// Returns a mutable reference to the entire underlying coefficient array. - fn raw_mut(&mut self) -> &mut [Self::Scalar] { - unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } - } - /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] @@ -161,6 +139,23 @@ pub trait ZnxLayout: ZnxInfos { unsafe { self.as_ptr().add(offset) } } + /// Returns non-mutable reference to the (i, j)-th small polynomial. + fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } + } +} + +pub trait ZnxViewMut: ZnxView + DataViewMut> { + /// Returns a mutable pointer to the underlying coefficients array. + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar + } + + /// Returns a mutable reference to the entire underlying coefficient array. 
+ fn raw_mut(&mut self) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } + } + /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] @@ -172,17 +167,15 @@ pub trait ZnxLayout: ZnxInfos { unsafe { self.as_mut_ptr().add(offset) } } - /// Returns non-mutable reference to the (i, j)-th small polynomial. - fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } - } - /// Returns mutable reference to the (i, j)-th small polynomial. fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } } +//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known +impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} + use std::convert::TryFrom; use std::num::TryFromIntError; use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; @@ -213,7 +206,7 @@ impl IntegerType for i128 { const BITS: u32 = 128; } -pub trait ZnxZero: ZnxLayout +pub trait ZnxZero: ZnxViewMut where Self: Sized, { @@ -238,16 +231,16 @@ where } } -pub trait ZnxRsh: ZnxLayout + ZnxZero -where - Self: Sized, - Self::Scalar: IntegerType, -{ +pub trait ZnxRsh: ZnxZero { fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { rsh(k, log_base2k, self, col, carry) } } +// Blanket implementations +impl ZnxZero for T where T: ZnxViewMut {} +impl ZnxRsh for T where T: ZnxZero {} + pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) where V::Scalar: IntegerType, @@ -310,10 +303,7 @@ pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) -where - ::Scalar: IntegerType, -{ +pub fn switch_degree(b: &mut DMut, col_b: 
usize, a: &D, col_a: usize) { let (n_in, n_out) = (a.n(), b.n()); let (gap_in, gap_out): (usize, usize); @@ -334,3 +324,64 @@ where .for_each(|(x_in, x_out)| *x_out = *x_in); }); } + +// pub trait ZnxLayout: ZnxInfos { +// type Scalar; + +// /// Returns true if the receiver is only borrowing the data. +// fn borrowing(&self) -> bool { +// self.znx().data.len() == 0 +// } + +// /// Returns a non-mutable pointer to the underlying coefficients array. +// fn as_ptr(&self) -> *const Self::Scalar { +// self.znx().ptr as *const Self::Scalar +// } + +// /// Returns a mutable pointer to the underlying coefficients array. +// fn as_mut_ptr(&mut self) -> *mut Self::Scalar { +// self.znx_mut().ptr as *mut Self::Scalar +// } + +// /// Returns a non-mutable reference to the entire underlying coefficient array. +// fn raw(&self) -> &[Self::Scalar] { +// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } +// } + +// /// Returns a mutable reference to the entire underlying coefficient array. +// fn raw_mut(&mut self) -> &mut [Self::Scalar] { +// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } +// } + +// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. +// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { +// #[cfg(debug_assertions)] +// { +// assert!(i < self.cols()); +// assert!(j < self.size()); +// } +// let offset: usize = self.n() * (j * self.cols() + i); +// unsafe { self.as_ptr().add(offset) } +// } + +// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. +// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { +// #[cfg(debug_assertions)] +// { +// assert!(i < self.cols()); +// assert!(j < self.size()); +// } +// let offset: usize = self.n() * (j * self.cols() + i); +// unsafe { self.as_mut_ptr().add(offset) } +// } + +// /// Returns non-mutable reference to the (i, j)-th small polynomial. 
+// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { +// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } +// } + +// /// Returns mutable reference to the (i, j)-th small polynomial. +// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { +// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } +// } +// } From ff8370e0235d2b1df8aa559494b48b478467ec92 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sat, 3 May 2025 16:37:20 +0530 Subject: [PATCH 25/87] everything compiles. Scratchpad not yet implemented --- base2k/examples/rlwe_encrypt.rs | 2 +- base2k/src/encoding.rs | 40 +++-- base2k/src/lib.rs | 26 ++++ base2k/src/mat_znx_dft.rs | 145 +++++++++++------ base2k/src/mat_znx_dft_ops.rs | 258 +++++++++++++++++-------------- base2k/src/module.rs | 2 +- base2k/src/sampling.rs | 61 +++++--- base2k/src/scalar_znx.rs | 174 +++++++++++++-------- base2k/src/scalar_znx_dft.rs | 120 ++++++++------ base2k/src/scalar_znx_dft_ops.rs | 59 ++++--- base2k/src/stats.rs | 2 +- base2k/src/vec_znx.rs | 13 +- base2k/src/vec_znx_big.rs | 6 +- base2k/src/vec_znx_big_ops.rs | 158 ++++++++++++++++--- base2k/src/vec_znx_dft.rs | 17 +- base2k/src/vec_znx_dft_ops.rs | 10 +- base2k/src/vec_znx_ops.rs | 131 ++++++++++++---- base2k/src/znx_base.rs | 195 +++++++++++------------ rlwe/src/automorphism.rs | 4 +- 19 files changed, 919 insertions(+), 504 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 2f08633..afac2f8 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, + Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, }; use itertools::izip; diff --git a/base2k/src/encoding.rs 
b/base2k/src/encoding.rs index b7d014d..ba48474 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::znx_base::ZnxLayout; +use crate::znx_base::{ZnxView, ZnxViewMut}; use crate::{VecZnx, znx_base::ZnxInfos}; use itertools::izip; use rug::{Assign, Float}; @@ -59,7 +59,7 @@ pub trait Encoding { fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; } -impl Encoding for VecZnx { +impl + AsRef<[u8]>> Encoding for VecZnx { fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } @@ -81,7 +81,14 @@ impl Encoding for VecZnx { } } -fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { +fn encode_vec_i64 + AsRef<[u8]>>( + a: &mut VecZnx, + col_i: usize, + log_base2k: usize, + log_k: usize, + data: &[i64], + log_max: usize, +) { let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] @@ -132,7 +139,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, } } -fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { +fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { @@ -160,7 +167,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat }) } -fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { +fn decode_vec_float + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { let size: usize = a.size(); #[cfg(debug_assertions)] { @@ -194,7 +201,15 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo }); } -fn encode_coeff_i64(a: &mut VecZnx, 
col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { +fn encode_coeff_i64 + AsRef<[u8]>>( + a: &mut VecZnx, + col_i: usize, + log_base2k: usize, + log_k: usize, + i: usize, + value: i64, + log_max: usize, +) { let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] @@ -237,7 +252,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz } } -fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); @@ -263,10 +278,9 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{ - Encoding, FFT64, Module, VecZnx, VecZnxOps, - znx_base::{ZnxInfos, ZnxLayout}, - }; + use crate::vec_znx_ops::*; + use crate::znx_base::*; + use crate::{Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; use itertools::izip; use sampling::source::Source; @@ -277,7 +291,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = module.new_vec_znx(2, size); + let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -299,7 +313,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = module.new_vec_znx(2, size); + let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 73d90c2..7ae1193 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -125,3 +125,29 @@ pub fn 
alloc_aligned(size: usize) -> Vec { DEFAULTALIGN, ) } + +pub(crate) struct ScratchSpace { + // data: D, +} + +impl ScratchSpace { + fn tmp_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> VecZnxDft { + todo!() + } + + fn tmp_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> VecZnxBig { + todo!() + } + + fn vec_znx_big_normalize_tmp_bytes(&mut self, module: &Module) -> &mut [u8] { + todo!() + } + + fn vmp_apply_dft_tmp_bytes(&mut self, module: &Module) -> &mut [u8] { + todo!() + } + + fn vmp_apply_dft_to_dft_tmp_bytes(&mut self, module: &Module) -> &mut [u8] { + todo!() + } +} diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 470adcc..34c711a 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,5 +1,5 @@ -use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, alloc_aligned}; +use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -8,68 +8,67 @@ use std::marker::PhantomData; /// /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. /// See the trait [MatZnxDftOps] for additional information. 
-pub struct MatZnxDft { - pub inner: ZnxBase, - pub cols_in: usize, - pub cols_out: usize, +pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, _marker: PhantomData, } -impl GetZnxBase for MatZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + self.rows } -} -impl ZnxInfos for MatZnxDft {} + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } -impl ZnxSliceSize for MatZnxDft { fn sl(&self) -> usize { self.n() } } -impl ZnxLayout for MatZnxDft { +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { type Scalar = f64; } -impl MatZnxDft { - pub fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let bytes: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self::from_bytes(module, rows, cols_in, cols_out, size, bytes) +impl MatZnxDft { + pub(crate) fn cols_in(&self) -> usize { + self.cols_in } - pub fn from_bytes(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec) -> Self { - let mut mat: MatZnxDft = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes); - mat.znx_mut().data = bytes; - mat + pub(crate) fn cols_out(&self) -> usize { + self.cols_out } +} - pub fn from_bytes_borrow( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> Self { - debug_assert_eq!( - bytes.len(), - Self::bytes_of(module, rows, cols_in, cols_out, size) - ); - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), - cols_in: cols_in, - cols_out: 
cols_out, - _marker: PhantomData, - } - } - - pub fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { unsafe { crate::ffi::vmp::bytes_of_vmp_pmat( module.ptr, @@ -79,16 +78,62 @@ impl MatZnxDft { } } - pub fn cols_in(&self) -> usize { - self.cols_in + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _marker: PhantomData, + } } - pub fn cols_out(&self) -> usize { - self.cols_out + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _marker: PhantomData, + } } + + // pub fn from_bytes_borrow( + // module: &Module, + // rows: usize, + // cols_in: usize, + // cols_out: usize, + // size: usize, + // bytes: &mut [u8], + // ) -> Self { + // debug_assert_eq!( + // bytes.len(), + // Self::bytes_of(module, rows, cols_in, cols_out, size) + // ); + // Self { + // inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), + // cols_in: cols_in, + // cols_out: cols_out, + // _marker: PhantomData, + // } + // } } -impl MatZnxDft { +impl> MatZnxDft { /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. 
/// /// # Arguments @@ -123,3 +168,5 @@ impl MatZnxDft { } } } + +pub type MatZnxDftAllocOwned = MatZnxDft, B>; diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 48c3834..62b56a1 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -1,20 +1,19 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; -use crate::znx_base::{ZnxInfos, ZnxLayout}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned, }; -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { +pub trait MatZnxDftAlloc { /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). 
- fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft; + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned; fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; @@ -25,17 +24,21 @@ pub trait MatZnxDftOps { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDft; + ) -> MatZnxDftAllocOwned; - fn new_mat_znx_dft_from_bytes_borrow( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> MatZnxDft; + // fn new_mat_znx_dft_from_bytes_borrow( + // &self, + // rows: usize, + // cols_in: usize, + // cols_out: usize, + // size: usize, + // bytes: &mut [u8], + // ) -> MatZnxDft; +} +/// This trait implements methods for vector matrix product, +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. +pub trait MatZnxDftOps { /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; @@ -49,7 +52,14 @@ pub trait MatZnxDftOps { /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. 
- fn vmp_prepare_row(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]); + fn vmp_prepare_row( + &self, + b: &mut MatZnxDft, + b_row: usize, + b_col_in: usize, + a: &VecZnx, + scratch: &mut ScratchSpace, + ); /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; @@ -64,11 +74,11 @@ pub trait MatZnxDftOps { fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx, + a: &MatZnxDft, b_row: usize, b_col_in: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. @@ -80,7 +90,7 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// @@ -89,7 +99,7 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// @@ -133,7 +143,7 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. 
- fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchSpace); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// @@ -180,16 +190,22 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); + fn vmp_apply_dft_to_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ); } -impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols_in, cols_out, size) +impl MatZnxDftAlloc for Module { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + MatZnxDftAllocOwned::bytes_of(self, rows, cols_in, cols_out, size) } - fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - MatZnxDft::::bytes_of(self, rows, cols_in, cols_out, size) + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned { + MatZnxDftAllocOwned::new(self, rows, cols_in, cols_out, size) } fn new_mat_znx_dft_from_bytes( @@ -199,26 +215,28 @@ impl MatZnxDftOps for Module { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDft { - MatZnxDft::::from_bytes(self, rows, cols_in, cols_out, size, bytes) - } - - fn new_mat_znx_dft_from_bytes_borrow( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> MatZnxDft { - MatZnxDft::::from_bytes_borrow(self, rows, cols_in, cols_out, size, 
bytes) + ) -> MatZnxDftAllocOwned { + MatZnxDftAllocOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) } +} +impl MatZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - self.bytes_of_vec_znx_dft(cols_out, size) + >::bytes_of_vec_znx_dft(self, cols_out, size) } - fn vmp_prepare_row(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]) { + fn vmp_prepare_row( + &self, + b: &mut MatZnxDft, + b_row: usize, + b_col_in: usize, + a: &VecZnx, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -249,33 +267,36 @@ impl MatZnxDftOps for Module { b.size(), a.size() ); - assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size())); - assert!(is_aligned(tmp_bytes.as_ptr())) + // assert!( + // tmp_bytes.len() + // >= >::vmp_prepare_row_tmp_bytes(self, a.cols(), a.size()) + // ); + // assert!(is_aligned(tmp_bytes.as_ptr())) } let cols_out: usize = a.cols(); let a_size: usize = a.size(); - let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - - let mut a_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft); + // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); + let mut a_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, a_size); (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); } fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes() + self.bytes_of_vec_znx_dft(cols_out, size) + + >::vec_znx_big_normalize_tmp_bytes(self) } fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx, + a: &MatZnxDft, a_row: 
usize, a_col_in: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ) { #[cfg(debug_assertions)] { @@ -307,24 +328,24 @@ impl MatZnxDftOps for Module { b.size(), a.size() ); - assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); - assert!(is_aligned(tmp_bytes.as_ptr())) + // assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); + // assert!(is_aligned(tmp_bytes.as_ptr())) } let cols_out: usize = b.cols(); let size: usize = b.size(); - let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); - let mut b_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft); + // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); + let mut b_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, size); Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); - let mut b_big: VecZnxBig = b_dft.alias_as_vec_znx_big(); + let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size); (0..cols_out).for_each(|i| { - self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i); - self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes); + >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); + self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); }); } - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -369,7 +390,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -433,18 +454,13 @@ impl MatZnxDftOps for Module { } } - fn 
vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_tmp_bytes( - c.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); + fn vmp_apply_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnx, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(c.n(), self.n()); @@ -464,18 +480,18 @@ impl MatZnxDftOps for Module { a.cols(), b.cols_in() ); - assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_tmp_bytes( - c.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); - assert_alignement(tmp_bytes.as_ptr()); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_tmp_bytes( + // c.size(), + // a.size(), + // b.rows(), + // b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft( @@ -488,7 +504,7 @@ impl MatZnxDftOps for Module { b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, - tmp_bytes.as_mut_ptr(), + scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(), ) } } @@ -515,7 +531,13 @@ impl MatZnxDftOps for Module { } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(c.n(), self.n()); @@ -535,20 +557,20 @@ impl MatZnxDftOps for Module { a.cols(), b.cols_in() ); - assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); - assert_alignement(tmp_bytes.as_ptr()); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_to_dft_tmp_bytes( + // c.cols(), + // c.size(), + // a.cols(), + // a.size(), + // b.rows(), + // 
b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft_to_dft( @@ -560,7 +582,7 @@ impl MatZnxDftOps for Module { b.as_ptr() as *const vmp::vmp_pmat_t, b.rows() as u64, (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), + scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(), ) } } @@ -568,9 +590,12 @@ impl MatZnxDftOps for Module { #[cfg(test)] mod tests { + use crate::mat_znx_dft_ops::*; + use crate::vec_znx_big_ops::*; + use crate::vec_znx_dft_ops::*; + use crate::vec_znx_ops::*; use crate::{ - FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - alloc_aligned, znx_base::ZnxLayout, + FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned, }; use sampling::source::Source; @@ -582,16 +607,19 @@ mod tests { let mat_cols_in: usize = 2; let mat_cols_out: usize = 2; let mat_size: usize = 5; - let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut a_big: VecZnxBig = module.new_vec_znx_big(mat_cols_out, mat_size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut a: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); + let mut b: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut a_big: VecZnxBig<_, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let 
mut vmpmat_0: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut vmpmat_1: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + // let mut tmp_bytes: Vec = + // alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch = ScratchSpace {}; let mut tmp_bytes: Vec = - alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); + alloc_aligned::( as VecZnxDftOps, Vec, _>>::vec_znx_idft_tmp_bytes(&module)); for col_in in 0..mat_cols_in { for row_i in 0..mat_rows { @@ -602,7 +630,7 @@ mod tests { module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes); + module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); @@ -613,11 +641,11 @@ mod tests { assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes); + module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch); (0..mat_cols_out).for_each(|col_out| { module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); - module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes); + module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch); }); assert_eq!(a.raw(), b.raw()); diff --git a/base2k/src/module.rs b/base2k/src/module.rs index c1799be..0e7d124 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -33,7 +33,7 @@ impl Backend for NTT120 { pub struct Module { pub ptr: *mut MODULE, - pub n: usize, + n: usize, _marker: PhantomData, } 
diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b52c4db..a8b1962 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,16 +1,24 @@ -use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout}; +use crate::znx_base::ZnxViewMut; +use crate::{Backend, Module, VecZnx}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source); - - /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( + fn fill_uniform + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut VecZnx, + col_i: usize, + size: usize, + source: &mut Source, + ); + + /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. + fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + &self, + log_base2k: usize, + a: &mut VecZnx, col_i: usize, log_k: usize, source: &mut Source, @@ -19,10 +27,10 @@ pub trait Sampling { ); /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. 
- fn add_normal( + fn add_normal + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut VecZnx, col_i: usize, log_k: usize, source: &mut Source, @@ -32,22 +40,29 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) { + fn fill_uniform + AsRef<[u8]>>( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + size: usize, + source: &mut Source, + ) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; (0..size).for_each(|j| { - a.at_mut(col_a, j) + a.at_mut(col_i, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); }) } - fn add_dist_f64>( + fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( &self, log_base2k: usize, - a: &mut VecZnx, - col_a: usize, + a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -63,7 +78,7 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_mut(col_a, limb).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -71,7 +86,7 @@ impl Sampling for Module { *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { - a.at_mut(col_a, limb).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -81,11 +96,11 @@ impl Sampling for Module { } } - fn add_normal( + fn add_normal + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, - col_a: usize, + a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, sigma: f64, @@ -94,7 +109,7 @@ impl Sampling for Module { self.add_dist_f64( log_base2k, a, - col_a, + col_i, log_k, source, Normal::new(0.0, sigma).unwrap(), @@ -106,7 
+121,9 @@ impl Sampling for Module { #[cfg(test)] mod tests { use super::Sampling; - use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout}; + use crate::vec_znx_ops::*; + use crate::znx_base::*; + use crate::{FFT64, Module, Stats, VecZnx}; use sampling::source::Source; #[test] @@ -120,7 +137,7 @@ mod tests { let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = module.new_vec_znx(cols, size); + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -154,7 +171,7 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = module.new_vec_znx(cols, size); + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index df3e6d1..c5052eb 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,64 +1,59 @@ -use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, GetZnxBase, Module, VecZnx}; +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -pub const SCALAR_ZNX_ROWS: usize = 1; -pub const SCALAR_ZNX_SIZE: usize = 1; +// pub const SCALAR_ZNX_ROWS: usize = 1; +// pub const SCALAR_ZNX_SIZE: usize = 1; -pub struct Scalar { - pub inner: ZnxBase, +pub struct Scalar { + data: D, + n: usize, + cols: usize, } -impl GetZnxBase for Scalar { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for Scalar { + fn cols(&self) -> usize { + 
self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner - } -} - -impl ZnxInfos for Scalar {} - -impl ZnxAlloc for Scalar { - type Scalar = i64; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), - } + fn rows(&self) -> usize { + 1 } - fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { - debug_assert_eq!( - _rows, SCALAR_ZNX_ROWS, - "rows != {} not supported for Scalar", - SCALAR_ZNX_ROWS - ); - debug_assert_eq!( - _size, SCALAR_ZNX_SIZE, - "rows != {} not supported for Scalar", - SCALAR_ZNX_SIZE - ); - module.n() * cols * std::mem::size_of::() + fn n(&self) -> usize { + self.n } -} -impl ZnxLayout for Scalar { - type Scalar = i64; -} + fn size(&self) -> usize { + 1 + } -impl ZnxSliceSize for Scalar { fn sl(&self) -> usize { self.n() } } -impl Scalar { +impl DataView for Scalar { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for Scalar { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for Scalar { + type Scalar = i64; +} + +impl + AsRef<[u8]>> Scalar { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; @@ -76,38 +71,89 @@ impl Scalar { self.at_mut(col, 0).shuffle(source); } - pub fn alias_as_vec_znx(&self) -> VecZnx { - VecZnx { - inner: ZnxBase { - n: self.n(), - rows: 1, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr() as *mut u8, - }, + // pub fn alias_as_vec_znx(&self) -> VecZnx { + // VecZnx { + // inner: ZnxBase { + // n: self.n(), + // rows: 1, + // cols: 1, + // size: 1, + // data: Vec::new(), + // ptr: self.ptr() as *mut u8, + // }, + // } + // } +} + +impl>> Scalar { + pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { + n * cols * 
size_of::() + } + + pub(crate) fn new(n: usize, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, + } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, } } } -pub trait ScalarOps { +pub type ScalarOwned = Scalar>; + +pub trait ScalarAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; - fn new_scalar(&self, cols: usize) -> Scalar; - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar; - fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; + fn new_scalar(&self, cols: usize) -> ScalarOwned; + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; + // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; } -impl ScalarOps for Module { +impl ScalarAlloc for Module { fn bytes_of_scalar(&self, cols: usize) -> usize { - Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + ScalarOwned::bytes_of::(self.n(), cols) } - fn new_scalar(&self, cols: usize) -> Scalar { - Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + fn new_scalar(&self, cols: usize) -> ScalarOwned { + ScalarOwned::new::(self.n(), cols) } - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar { - Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) - } - fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { + ScalarOwned::new_from_bytes::(self.n(), cols, bytes) } + // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { + // Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + // } } + +// 
impl ZnxAlloc for Scalar { +// type Scalar = i64; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { +// Self { +// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { +// debug_assert_eq!( +// _rows, SCALAR_ZNX_ROWS, +// "rows != {} not supported for Scalar", +// SCALAR_ZNX_ROWS +// ); +// debug_assert_eq!( +// _size, SCALAR_ZNX_SIZE, +// "rows != {} not supported for Scalar", +// SCALAR_ZNX_SIZE +// ); +// module.n() * cols * std::mem::size_of::() +// } +// } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 6fdb991..09b26d4 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -1,67 +1,97 @@ use std::marker::PhantomData; use crate::ffi::svp; -use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, GetZnxBase, Module}; +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; pub const SCALAR_ZNX_DFT_ROWS: usize = 1; pub const SCALAR_ZNX_DFT_SIZE: usize = 1; -pub struct ScalarZnxDft { - pub inner: ZnxBase, - _marker: PhantomData, +pub struct ScalarZnxDft { + data: D, + n: usize, + cols: usize, + _phantom: PhantomData, } -impl GetZnxBase for ScalarZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for ScalarZnxDft { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + 1 + } + + fn sl(&self) -> usize { + self.n() } } -impl ZnxInfos for ScalarZnxDft {} - -impl ZnxAlloc for ScalarZnxDft { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - 
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); - Self { - inner: ZnxBase::from_bytes_borrow( - module.n(), - SCALAR_ZNX_DFT_ROWS, - cols, - SCALAR_ZNX_DFT_SIZE, - bytes, - ), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { - debug_assert_eq!( - _rows, SCALAR_ZNX_DFT_ROWS, - "rows != {} not supported for ScalarZnxDft", - SCALAR_ZNX_DFT_ROWS - ); - debug_assert_eq!( - _size, SCALAR_ZNX_DFT_SIZE, - "rows != {} not supported for ScalarZnxDft", - SCALAR_ZNX_DFT_SIZE - ); - unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } +impl DataView for ScalarZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for ScalarZnxDft { +impl DataViewMut for ScalarZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for ScalarZnxDft { type Scalar = f64; } -impl ZnxSliceSize for ScalarZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols() +impl>, B: Backend> ScalarZnxDft { + pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { + unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } } + + pub(crate) fn new(module: &Module, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes(module: &Module, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } + + // fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { + // debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); + // Self { + // inner: ZnxBase::from_bytes_borrow( + // module.n(), + // SCALAR_ZNX_DFT_ROWS, + // cols, + // SCALAR_ZNX_DFT_SIZE, + // bytes, + // 
), + // _phantom: PhantomData, + // } + // } } + +pub type ScalarZnxDftOwned = ScalarZnxDft, B>; diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index 4fbe99d..fc56e4e 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,35 +1,52 @@ use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, FFT64, Module, Scalar, ScalarZnxDft, ScalarZnxDftOwned, VecZnx, VecZnxDft}; -pub trait ScalarZnxDftOps { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft; +pub trait ScalarZnxDftAlloc { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft; - fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); - fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, a_col: usize, b: &VecZnx, b_col: usize); + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; + // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } -impl ScalarZnxDftOps for Module { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft { - ScalarZnxDft::::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); + fn svp_apply_dft( + &self, + res: &mut VecZnxDft, + res_col: usize, + a: &ScalarZnxDft, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); +} 
+ +impl ScalarZnxDftAlloc for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new(self, cols) } fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { - ScalarZnxDft::::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + ScalarZnxDftOwned::bytes_of(self, cols) } - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft { - ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) } - fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) - } + // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { + // ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + // } +} - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { +impl ScalarZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { unsafe { svp::svp_prepare( self.ptr, @@ -41,11 +58,11 @@ impl ScalarZnxDftOps for Module { fn svp_apply_dft( &self, - res: &mut VecZnxDft, + res: &mut VecZnxDft, res_col: usize, - a: &ScalarZnxDft, + a: &ScalarZnxDft, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ) { unsafe { diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index a1946ab..c6d16b4 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -9,7 +9,7 @@ pub trait Stats { fn std(&self, col_i: usize, log_base2k: usize) -> f64; } -impl Stats for VecZnx { +impl + AsRef<[u8]>> Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { let prec: u32 = 
(self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b76f93d..3321f8e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,13 +1,10 @@ -use crate::Backend; use crate::DataView; use crate::DataViewMut; -use crate::Module; -use crate::ZnxView; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut, switch_degree}; use std::{cmp::min, fmt}; // pub const VEC_ZNX_ROWS: usize = 1; @@ -59,7 +56,7 @@ impl DataView for VecZnx { } impl DataViewMut for VecZnx { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -84,7 +81,7 @@ impl + AsRef<[u8]>> VecZnx { return; } - self.inner.size -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -97,7 +94,7 @@ impl + AsRef<[u8]>> VecZnx { } /// Switches degree of from `a.n()` to `self.n()` into `self` - pub fn switch_degree>(&mut self, col: usize, a: &Data, col_a: usize) { + pub fn switch_degree>(&mut self, col: usize, a: &VecZnx, col_a: usize) { switch_degree(self, col_a, a, col) } @@ -161,7 +158,7 @@ fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -fn normalize>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { +fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 682493a..72b15d7 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,11 +1,11 @@ use crate::ffi::vec_znx_big; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxView}; +use crate::znx_base::{ZnxInfos, ZnxView}; use 
crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; const VEC_ZNX_BIG_ROWS: usize = 1; -/// VecZnxBig is Backend dependent, denoted with backend generic `B` +/// VecZnxBig is `Backend` dependent, denoted with backend generic `B` pub struct VecZnxBig { data: D, n: usize, @@ -44,7 +44,7 @@ impl DataView for VecZnxBig { } impl DataViewMut for VecZnxBig { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 5353c32..bb46802 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, DataView, FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, DataView, FFT64, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -79,13 +79,13 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `a` to `b` and stores the result on `b`. + /// Subtracts `a` from `b` and stores the result on `b`. fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `b`. + /// Subtracts `b` from `a` and stores the result on `b`. fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `c`. + /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, @@ -96,10 +96,10 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `a` to `b` and stores the result on `b`. 
+ /// Subtracts `a` from `res` and stores the result on `res`. fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `c`. + /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, @@ -110,7 +110,7 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `b` to `a` and stores the result on `b`. + /// Subtracts `res` from `a` and stores the result on `res`. fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. @@ -129,7 +129,7 @@ pub trait VecZnxBigOps { res_col: usize, a: &VecZnxBig, a_col: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. @@ -160,7 +160,7 @@ impl VecZnxBigAlloc for Module { // } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, cols, size) + VecZnxBigOwned::bytes_of(self, cols, size) } } @@ -208,8 +208,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_add(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -245,7 +261,6 @@ where } } - //(Jay)TODO: check whether definitions sub_ab, sub_ba make sense to you fn vec_znx_big_sub_ab_inplace( &self, res: &mut VecZnxBig, @@ -253,8 +268,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub(self, res, 
res_col, a, a_col, res, res_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -265,8 +296,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub(self, res, res_col, res, res_col, a, a_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -309,8 +356,24 @@ where a: &VecZnx, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub_small_b(self, res, res_col, res, res_col, a, a_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -353,8 +416,24 @@ where a: &VecZnx, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub_small_a(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -391,11 +470,29 @@ where } fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { - Self::vec_znx_big_add_small(self, res, res_col, res, res_col, a, a_col); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + 
assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } } fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - Self::vec_znx_normalize_tmp_bytes(self) + >::vec_znx_normalize_tmp_bytes(self) } fn vec_znx_big_normalize( @@ -405,14 +502,16 @@ where res_col: usize, a: &VecZnxBig, a_col: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); - assert_alignement(tmp_bytes.as_ptr()); + //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. + // In the FFT backend the tmp sizes are same but will be different in the NTT backend + // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vec_znx::vec_znx_normalize_base2k( @@ -424,7 +523,7 @@ where a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - tmp_bytes.as_mut_ptr(), + scratch.vec_znx_big_normalize_tmp_bytes(self).as_mut_ptr(), ); } } @@ -457,8 +556,21 @@ where } fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } unsafe { - Self::vec_znx_big_automorphism(self, k, a, a_col, a, a_col); + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index c192486..74b559c 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,11 +1,12 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; 
-use crate::znx_base::{ZnxAlloc, ZnxInfos}; +use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; const VEC_ZNX_DFT_ROWS: usize = 1; +// VecZnxDft is `Backend` dependent denoted with generic `B` pub struct VecZnxDft { data: D, n: usize, @@ -44,7 +45,7 @@ impl DataView for VecZnxDft { } impl DataViewMut for VecZnxDft { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -84,6 +85,18 @@ impl>, B: Backend> VecZnxDft { pub type VecZnxDftOwned = VecZnxDft, B>; +impl<'a, D: ?Sized, B> VecZnxDft<&'a mut D, B> { + pub(crate) fn from_mut_slice(data: &'a mut D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + // impl ZnxAlloc for VecZnxDft { // type Scalar = u8; diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index cf2090b..2c1cc97 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,7 +1,5 @@ use crate::VecZnxDftOwned; -use crate::ffi::vec_znx_big; -use crate::ffi::vec_znx_dft; -use crate::znx_base::ZnxAlloc; +use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::znx_base::ZnxInfos; use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; use std::cmp::min; @@ -82,7 +80,7 @@ impl VecZnxDftAlloc for Module { // } fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, size) + VecZnxDftOwned::bytes_of(&self, cols, size) } } @@ -156,10 +154,10 @@ where #[cfg(debug_assertions)] { assert!( - tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), + tmp_bytes.len() >= >::vec_znx_idft_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", tmp_bytes.len(), - Self::vec_znx_idft_tmp_bytes(self) + >::vec_znx_idft_tmp_bytes(self) ); assert_alignement(tmp_bytes.as_ptr()) } diff --git a/base2k/src/vec_znx_ops.rs 
b/base2k/src/vec_znx_ops.rs index 339bc12..6951651 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -86,10 +86,14 @@ pub trait VecZnxOps { ); /// Subtracts the selected column of `a` from the selected column of `res` inplace. + /// + /// res[res_col] -= a[a_col] fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - // /// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`. - // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res` + /// + /// res[res_col] = a[a_col] - res[res_col] + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); // Negates the selected column of `a` and stores the result in `res_col` of `res`. fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); @@ -136,15 +140,15 @@ pub trait VecZnxOps { impl VecZnxAlloc for Module { //(Jay)TODO: One must define the Scalar generic param here. 
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { - VecZnxOwned::new(self.n(), cols, size) + VecZnxOwned::new::(self.n(), cols, size) } fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnxOwned::bytes_of(self.n(), cols, size) + VecZnxOwned::bytes_of::(self.n(), cols, size) } fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { - VecZnxOwned::new_from_bytes(self.n(), cols, size, bytes) + VecZnxOwned::new_from_bytes::(self.n(), cols, size, bytes) } } @@ -170,7 +174,7 @@ where { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); + assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -190,16 +194,8 @@ where fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_normalize( - self, - log_base2k, - &mut *a_ptr, - a_col, - &*a_ptr, - a_col, - tmp_bytes, - ); + let a_ptr: *const VecZnx<_> = a; + Self::vec_znx_normalize(self, log_base2k, a, a_col, &*a_ptr, a_col, tmp_bytes); } } @@ -236,8 +232,24 @@ where } fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_add(&self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -274,18 +286,48 @@ where } fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - let 
res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } - // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { - // unsafe { - // let res_ptr: *mut VecZnx = res as *mut VecZnx; - // Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); - // } - // } + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] @@ -308,7 +350,8 @@ where fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { unsafe { - Self::vec_znx_negate(self, a, a_col, a, a_col); + let a_ref: *const VecZnx<_> = a; + Self::vec_znx_negate(self, a, a_col, a_ref.as_ref().unwrap(), a_col); } } @@ -333,8 +376,21 @@ where } fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } unsafe { - Self::vec_znx_rotate(self, k, a, a_col, a, a_col); + vec_znx::vec_znx_rotate( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -359,8 +415,21 @@ where } fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), 
self.n()); + } unsafe { - Self::vec_znx_automorphism(self, k, a, a_col, a, a_col); + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -392,7 +461,7 @@ where self.vec_znx_rotate(-1, buf, 0, a, a_col); } else { switch_degree(bi, res_col, buf, a_col); - self.vec_znx_rotate_inplace(-1, buf, a_col); + >::vec_znx_rotate_inplace(self, -1, buf, a_col); } }) } @@ -414,9 +483,9 @@ where a.iter().enumerate().for_each(|(_, ai)| { switch_degree(res, res_col, ai, a_col); - self.vec_znx_rotate_inplace(-1, res, res_col); + >::vec_znx_rotate_inplace(self, -1, res, res_col); }); - self.vec_znx_rotate_inplace(a.len() as i64, res, res_col); + >::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col); } } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index bf941d4..a7361ad 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -85,26 +85,26 @@ pub trait ZnxInfos { // pub trait ZnxSliceSize {} //(Jay) TODO: Remove ZnxAlloc -pub trait ZnxAlloc -where - Self: Sized + ZnxInfos, -{ - type Scalar; - fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { - let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); - Self::from_bytes(module, rows, cols, size, bytes) - } +// pub trait ZnxAlloc +// where +// Self: Sized + ZnxInfos, +// { +// type Scalar; +// fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { +// let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); +// Self::from_bytes(module, rows, cols, size, bytes) +// } - fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { - let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); - res.znx_mut().data = bytes; - res - } +// fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { +// let mut res: 
Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); +// res.znx_mut().data = bytes; +// res +// } - fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; +// fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; -} +// fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; +// } pub trait DataView { type D; @@ -112,11 +112,11 @@ pub trait DataView { } pub trait DataViewMut: DataView { - fn data_mut(&self) -> &mut Self::D; + fn data_mut(&mut self) -> &mut Self::D; } pub trait ZnxView: ZnxInfos + DataView> { - type Scalar; + type Scalar: Copy; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { @@ -177,11 +177,9 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} use std::convert::TryFrom; -use std::num::TryFromIntError; use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; -pub trait IntegerType: +pub trait Num: Copy - + std::fmt::Debug + Default + PartialEq + PartialOrd @@ -190,22 +188,23 @@ pub trait IntegerType: + Mul + Div + Neg - + Shr - + Shl + AddAssign - + TryFrom { const BITS: u32; } -impl IntegerType for i64 { +impl Num for i64 { const BITS: u32 = 64; } -impl IntegerType for i128 { +impl Num for i128 { const BITS: u32 = 128; } +impl Num for f64 { + const BITS: u32 = 64; +} + pub trait ZnxZero: ZnxViewMut where Self: Sized, @@ -231,79 +230,16 @@ where } } -pub trait ZnxRsh: ZnxZero { - fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { - rsh(k, log_base2k, self, col, carry) - } -} - // Blanket implementations impl ZnxZero for T where T: ZnxViewMut {} -impl ZnxRsh for T where T: ZnxZero {} +// impl ZnxRsh for T where T: ZnxZero {} -pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: 
usize, tmp_bytes: &mut [u8]) -where - V::Scalar: IntegerType, -{ - let n: usize = a.n(); - let size: usize = a.size(); - let cols: usize = a.cols(); - - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= rsh_tmp_bytes::(n), - "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", - tmp_bytes.len() / size_of::(), - n, - size, - ); - assert_alignement(tmp_bytes.as_ptr()); - } - - let size: usize = a.size(); - let steps: usize = k / log_base2k; - - a.raw_mut().rotate_right(n * steps * cols); - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - a.zero_at(i, j); - }) - }); - - let k_rem: usize = k % log_base2k; - - if k_rem != 0 { - let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); - - unsafe { - std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); - } - - let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); - let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); - let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); - - (steps..size).for_each(|i| { - izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << log_base2k_t; - *ci = get_base_k_carry(*xi, shift); - *xi = (*xi - *ci) >> k_rem_t; - }); - }) - } -} - -#[inline(always)] -fn get_base_k_carry(x: T, shift: T) -> T { - (x << shift) >> shift -} - -pub fn rsh_tmp_bytes(n: usize) -> usize { - n * std::mem::size_of::() -} - -pub fn switch_degree(b: &mut DMut, col_b: usize, a: &D, col_a: usize) { +pub fn switch_degree + ZnxZero, D: ZnxView>( + b: &mut DMut, + col_b: usize, + a: &D, + col_a: usize, +) { let (n_in, n_out) = (a.n(), b.n()); let (gap_in, gap_out): (usize, usize); @@ -325,6 +261,71 @@ pub fn switch_degree(b: &mut DMut, col_b }); } +// (Jay)TODO: implement rsh for VecZnx, VecZnxBig +// pub trait ZnxRsh: ZnxZero { +// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { +// rsh(k, log_base2k, self, col, carry) +// } +// } +// pub fn rsh(k: 
usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) { +// let n: usize = a.n(); +// let size: usize = a.size(); +// let cols: usize = a.cols(); + +// #[cfg(debug_assertions)] +// { +// assert!( +// tmp_bytes.len() >= rsh_tmp_bytes::(n), +// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", +// tmp_bytes.len() / size_of::(), +// n, +// size, +// ); +// assert_alignement(tmp_bytes.as_ptr()); +// } + +// let size: usize = a.size(); +// let steps: usize = k / log_base2k; + +// a.raw_mut().rotate_right(n * steps * cols); +// (0..cols).for_each(|i| { +// (0..steps).for_each(|j| { +// a.zero_at(i, j); +// }) +// }); + +// let k_rem: usize = k % log_base2k; + +// if k_rem != 0 { +// let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); + +// unsafe { +// std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); +// } + +// let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); +// let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); +// let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); + +// (steps..size).for_each(|i| { +// izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| { +// *xi += *ci << log_base2k_t; +// *ci = get_base_k_carry(*xi, shift); +// *xi = (*xi - *ci) >> k_rem_t; +// }); +// }) +// } +// } + +// #[inline(always)] +// fn get_base_k_carry(x: T, shift: T) -> T { +// (x << shift) >> shift +// } + +// pub fn rsh_tmp_bytes(n: usize) -> usize { +// n * std::mem::size_of::() +// } + // pub trait ZnxLayout: ZnxInfos { // type Scalar; diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index 95a935f..ea2b834 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -7,8 +7,8 @@ use crate::{ parameters::Parameters, }; use base2k::{ - Module, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, - MatZnxDftOps, assert_alignement, + 
MatZnxDft, MatZnxDftOps, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, + VecZnxDft, VecZnxDftOps, VecZnxOps, assert_alignement, }; use sampling::source::Source; use std::collections::HashMap; From b82a1ca1b4cfb124a65d76f1e6bd01249555ecbe Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 4 May 2025 18:39:28 +0530 Subject: [PATCH 26/87] wip --- base2k/src/lib.rs | 196 ++++++++++++++++++++--- base2k/src/mat_znx_dft.rs | 48 +++--- base2k/src/mat_znx_dft_ops.rs | 278 ++++++++++++++++++-------------- base2k/src/vec_znx.rs | 53 +++--- base2k/src/vec_znx_big.rs | 46 +++++- base2k/src/vec_znx_big_ops.rs | 35 ++-- base2k/src/vec_znx_dft.rs | 16 +- base2k/src/vec_znx_dft_ops.rs | 11 +- base2k/src/vec_znx_ops.rs | 20 ++- base2k/src/znx_base.rs | 294 +++++++++------------------------- 10 files changed, 551 insertions(+), 446 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7ae1193..f33ce60 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -18,10 +18,17 @@ pub mod vec_znx_dft_ops; pub mod vec_znx_ops; pub mod znx_base; +use std::{ + any::type_name, + ops::{DerefMut, Sub}, +}; + pub use encoding::*; pub use mat_znx_dft::*; pub use mat_znx_dft_ops::*; pub use module::*; +use rand_core::le; +use rand_distr::num_traits::sign; pub use sampling::*; pub use scalar_znx::*; pub use scalar_znx_dft::*; @@ -126,28 +133,177 @@ pub fn alloc_aligned(size: usize) -> Vec { ) } -pub(crate) struct ScratchSpace { - // data: D, -} +pub struct ScratchOwned(Vec); -impl ScratchSpace { - fn tmp_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> VecZnxDft { - todo!() +impl ScratchOwned { + pub fn new(byte_count: usize) -> Self { + let data: Vec = alloc_aligned(byte_count); + Self(data) } - fn tmp_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> VecZnxBig { - todo!() - } - - fn vec_znx_big_normalize_tmp_bytes(&mut self, module: &Module) -> &mut [u8] { - todo!() - } - - fn vmp_apply_dft_tmp_bytes(&mut 
self, module: &Module) -> &mut [u8] { - todo!() - } - - fn vmp_apply_dft_to_dft_tmp_bytes(&mut self, module: &Module) -> &mut [u8] { - todo!() + pub fn borrow(&mut self) -> &mut ScratchBorr { + ScratchBorr::new(&mut self.0) } } + +pub struct ScratchBorr { + data: [u8], +} + +impl ScratchBorr { + fn new(data: &mut [u8]) -> &mut Self { + unsafe { &mut *(data as *mut [u8] as *mut Self) } + } + + fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr = data.as_mut_ptr(); + let self_len = data.len(); + + let aligned_offset = ptr.align_offset(DEFAULTALIGN); + let aligned_len = self_len.saturating_sub(aligned_offset); + + if let Some(rem_len) = aligned_len.checked_sub(take_len) { + unsafe { + let rem_ptr = ptr.add(aligned_offset).add(take_len); + let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + + let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + + return (take_slice, rem_slice); + } + } else { + panic!( + "Attempted to take {} from scratch with {} aligned bytes left", + take_len, + take_len, + // type_name::(), + // aligned_len + ); + } + } + + fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); + + unsafe { + ( + &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), + Self::new(rem_slice), + ) + } + } + + fn tmp_vec_znx_dft( + &mut self, + module: &Module, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size)); + + ( + VecZnxDft::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } + + fn tmp_vec_znx_big From<&'a mut [u8]>, B: Backend>( + &mut self, + module: &Module, + cols: usize, + size: usize, + ) -> (VecZnxBig, &mut Self) { + let (take_slice, 
rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size)); + + ( + VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size), + Self::new(rem_slice), + ) + } +} + +// pub struct ScratchBorrowed<'a> { +// data: &'a mut [u8], +// } + +// impl<'a> ScratchBorrowed<'a> { +// fn take_slice(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) { +// let ptr = self.data.as_mut_ptr(); +// let self_len = self.data.len(); + +// //TODO(Jay): print the offset sometimes, just to check +// let aligned_offset = ptr.align_offset(DEFAULTALIGN); +// let aligned_len = self_len.saturating_sub(aligned_offset); + +// let take_len_bytes = take_len * std::mem::size_of::(); + +// if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) { +// unsafe { +// let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes); +// let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + +// let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes); + +// return (take_slice, ScratchBorrowed { data: rem_slice }); +// } +// } else { +// panic!( +// "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left", +// take_len_bytes, +// take_len, +// type_name::(), +// aligned_len +// ); +// } +// } + +// fn reborrow(&mut self) -> ScratchBorrowed<'a> { +// //(Jay)TODO: `data: &mut *self.data` does not work because liftime of &mut self is different from 'a. +// // But it feels that there should be a simpler impl. 
than the one below +// Self { +// data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) }, +// } +// } + +// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) { +// let (data, re_scratch) = self.take_slice::(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); +// ( +// VecZnxDft::from_data(data, module.n(), cols, size), +// re_scratch, +// ) +// } + +// pub(crate) fn len(&self) -> usize { +// self.data.len() +// } +// } + +// pub trait Scratch { +// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (D, &mut Self); +// } + +// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr { +// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) { +// let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); +// ( +// data +// rem_scratch, +// ) +// } + +// // fn tmp_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) { +// // // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size)); +// // // ( +// // // VecZnxBig::from_data(data, module.n(), cols, size), +// // // re_scratch, +// // // ) +// // } + +// // fn scalar_slice(&mut self, len: usize) -> (&mut [T], Self) { +// // self.take_slice::(len) +// // } + +// // fn reborrow(&mut self) -> Self { +// // self.reborrow() +// // } +// } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 34c711a..7a39dd1 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos}; +use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; use std::marker::PhantomData; @@ -111,26 +111,6 @@ impl>, B: Backend> MatZnxDft { _marker: PhantomData, } } - - // pub fn from_bytes_borrow( - // 
module: &Module, - // rows: usize, - // cols_in: usize, - // cols_out: usize, - // size: usize, - // bytes: &mut [u8], - // ) -> Self { - // debug_assert_eq!( - // bytes.len(), - // Self::bytes_of(module, rows, cols_in, cols_out, size) - // ); - // Self { - // inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), - // cols_in: cols_in, - // cols_out: cols_out, - // _marker: PhantomData, - // } - // } } impl> MatZnxDft { @@ -170,3 +150,29 @@ impl> MatZnxDft { } pub type MatZnxDftAllocOwned = MatZnxDft, B>; + +impl MatZnxDft, B> { + pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + size: self.size, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + _marker: PhantomData, + } + } + + pub fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + size: self.size, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + _marker: PhantomData, + } + } +} diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 62b56a1..5ab44df 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, }; pub trait MatZnxDftAlloc { @@ -36,12 +36,55 @@ pub trait MatZnxDftAlloc { // ) -> MatZnxDft; } -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [MatZnxDft]. 
-pub trait MatZnxDftOps { +pub trait MatZnxDftScratch { /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; + /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] + fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; + + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. + /// + /// # Arguments + /// + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnx]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. + fn vmp_apply_dft_tmp_bytes( + &self, + c_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; + + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. + /// + /// # Arguments + /// + /// * `c_size`: number of size of the output [VecZnxDft]. + /// * `a_size`: number of size of the input [VecZnxDft]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + c_cols: usize, + c_size: usize, + a_cols: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; +} + +/// This trait implements methods for vector matrix product, +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. +pub trait MatZnxDftOps { /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. 
/// /// # Arguments @@ -58,12 +101,9 @@ pub trait MatZnxDftOps { b_row: usize, b_col_in: usize, a: &VecZnx, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ); - /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. /// /// # Arguments @@ -78,7 +118,7 @@ pub trait MatZnxDftOps { a: &MatZnxDft, b_row: usize, b_col_in: usize, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. @@ -101,24 +141,6 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to extract. fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes( - &self, - c_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize; - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] @@ -143,27 +165,7 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchSpace); - - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. 
- /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_to_dft_tmp_bytes( - &self, - c_cols: usize, - c_size: usize, - a_cols: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize; + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchBorr); /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -195,7 +197,7 @@ pub trait MatZnxDftOps { c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ); } @@ -220,22 +222,70 @@ impl MatZnxDftAlloc for Module { } } -impl MatZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ +impl MatZnxDftScratch for Module { fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - >::bytes_of_vec_znx_dft(self, cols_out, size) + >::bytes_of_vec_znx_dft(self, cols_out, size) } + fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { + >::bytes_of_vec_znx_dft(self, cols_out, size) + + ::vec_znx_big_normalize_tmp_bytes(self) + } + + fn vmp_apply_dft_tmp_bytes( + &self, + c_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + unsafe { + vmp::vmp_apply_dft_tmp_bytes( + self.ptr, + c_size as u64, + a_size as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, + ) as usize + } + } + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + c_cols: usize, + c_size: usize, + a_cols: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + unsafe { + 
vmp::vmp_apply_dft_to_dft_tmp_bytes( + self.ptr, + (c_size * c_cols) as u64, + (a_size * a_cols) as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, + ) as usize + } + } +} + +impl MatZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>, + Data: AsRef<[u8]>, +{ fn vmp_prepare_row( &self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ) { #[cfg(debug_assertions)] { @@ -278,17 +328,13 @@ where let a_size: usize = a.size(); // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - let mut a_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, a_size); + let (mut a_dft, _) = scratch.tmp_scalar_slice(12); + DataMut::from(a_dft); + // let (mut a_dft, _) = scratch.tmp_vec_znx_dft::(self, cols_out, a_size); (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); - Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); } - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - self.bytes_of_vec_znx_dft(cols_out, size) - + >::vec_znx_big_normalize_tmp_bytes(self) - } - fn vmp_extract_row( &self, log_base2k: usize, @@ -296,7 +342,7 @@ where a: &MatZnxDft, a_row: usize, a_col_in: usize, - scratch: &mut ScratchSpace, + mut scratch: &mut ScratchBorr, ) { #[cfg(debug_assertions)] { @@ -336,9 +382,9 @@ where let size: usize = b.size(); // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); - let mut b_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, size); + let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size); Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); - let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size); + let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size); (0..cols_out).for_each(|i| { >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut 
b_dft, i); self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); @@ -434,32 +480,12 @@ where } } - fn vmp_apply_dft_tmp_bytes( - &self, - res_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize { - unsafe { - vmp::vmp_apply_dft_tmp_bytes( - self.ptr, - res_size as u64, - a_size as u64, - (b_rows * b_cols_in) as u64, - (b_size * b_cols_out) as u64, - ) as usize - } - } - fn vmp_apply_dft( &self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, - scratch: &mut ScratchSpace, + mut scratch: &mut ScratchBorr, ) { #[cfg(debug_assertions)] { @@ -493,6 +519,16 @@ where // ); // assert_alignement(tmp_bytes.as_ptr()); } + let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vmp_apply_dft_tmp_bytes( + self, + c.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { vmp::vmp_apply_dft( self.ptr, @@ -504,39 +540,17 @@ where b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, - scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(), + tmp_bytes.as_mut_ptr(), ) } } - fn vmp_apply_dft_to_dft_tmp_bytes( - &self, - res_cols: usize, - res_size: usize, - a_size: usize, - a_cols: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize { - unsafe { - vmp::vmp_apply_dft_to_dft_tmp_bytes( - self.ptr, - (res_size * res_cols) as u64, - (a_size * a_cols) as u64, - (b_rows * b_cols_in) as u64, - (b_size * b_cols_out) as u64, - ) as usize - } - } - fn vmp_apply_dft_to_dft( &self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, - scratch: &mut ScratchSpace, + mut scratch: &mut ScratchBorr, ) { #[cfg(debug_assertions)] { @@ -572,6 +586,17 @@ where // ); // assert_alignement(tmp_bytes.as_ptr()); } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( + c.cols(), + c.size(), + a.cols(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); unsafe { 
vmp::vmp_apply_dft_to_dft( self.ptr, @@ -582,7 +607,7 @@ where b.as_ptr() as *const vmp::vmp_pmat_t, b.rows() as u64, (b.size() * b.cols()) as u64, - scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(), + tmp_bytes.as_mut_ptr(), ) } } @@ -590,6 +615,7 @@ where #[cfg(test)] mod tests { + use crate::ScratchOwned; use crate::mat_znx_dft_ops::*; use crate::vec_znx_big_ops::*; use crate::vec_znx_dft_ops::*; @@ -617,7 +643,9 @@ mod tests { // let mut tmp_bytes: Vec = // alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); - let mut scratch = ScratchSpace {}; + let mut scratch = ScratchOwned::new( + 2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()), + ); let mut tmp_bytes: Vec = alloc_aligned::( as VecZnxDftOps, Vec, _>>::vec_znx_idft_tmp_bytes(&module)); @@ -630,7 +658,9 @@ mod tests { module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch); + // let g = vmpmat_0.to_mut(); + + module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow()); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); @@ -641,11 +671,25 @@ mod tests { assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch); + // module.vmp_extract_row( + // log_base2k, + // &mut b.to_mut(), + // &vmpmat_0.to_ref(), + // row_i, + // col_in, + // scratch.borrow(), + // ); (0..mat_cols_out).for_each(|col_out| { module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); - module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch); + module.vec_znx_big_normalize( + log_base2k, + &mut a.to_mut(), + col_out, + &a_big.to_ref(), + col_out, + 
scratch.borrow(), + ); }); assert_eq!(a.raw(), b.raw()); diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 3321f8e..b386604 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -97,11 +97,6 @@ impl + AsRef<[u8]>> VecZnx { pub fn switch_degree>(&mut self, col: usize, a: &VecZnx, col_a: usize) { switch_degree(self, col_a, a, col) } - - // Prints the first `n` coefficients of each limb - // pub fn print(&self, n: usize, col: usize) { - // (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); - // } } impl>> VecZnx { @@ -131,8 +126,6 @@ impl>> VecZnx { } } -//(Jay)TODO: Impl. truncate pow2 for Owned Vector - /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -148,12 +141,6 @@ where data_b[..size].copy_from_slice(&data_a[..size]) } -// if !self.borrowing() { -// self.inner -// .data -// .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); -// } - fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } @@ -190,26 +177,6 @@ fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, } } -// impl ZnxAlloc for VecZnx { -// type Scalar = i64; - -// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { -// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); -// VecZnx { -// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), -// } -// } - -// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { -// debug_assert_eq!( -// _rows, VEC_ZNX_ROWS, -// "rows != {} not supported for VecZnx", -// VEC_ZNX_ROWS -// ); -// module.n() * cols * size * size_of::() -// } -// } - impl> fmt::Display for VecZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( @@ -248,3 +215,23 @@ impl> fmt::Display for VecZnx { pub type VecZnxOwned = VecZnx>; pub type 
VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; + +impl VecZnx> { + pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + } + } + + pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + } + } +} diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 72b15d7..7442f11 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -53,13 +53,13 @@ impl> ZnxView for VecZnxBig { type Scalar = i64; } -impl>, B: Backend> VecZnxBig { - pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } - } +pub(crate) fn bytes_of_vec_znx_big(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } +} +impl>, B: Backend> VecZnxBig { pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + let data = alloc_aligned::(bytes_of_vec_znx_big(module, cols, size)); Self { data: data.into(), n: module.n(), @@ -71,7 +71,7 @@ impl>, B: Backend> VecZnxBig { pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, cols, size)); + assert!(data.len() == bytes_of_vec_znx_big(module, cols, size)); Self { data: data.into(), n: module.n(), @@ -82,8 +82,42 @@ impl>, B: Backend> VecZnxBig { } } +impl VecZnxBig { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + pub type VecZnxBigOwned = VecZnxBig, B>; +impl VecZnxBig, B> { + pub(crate) fn to_mut(&mut self) -> 
VecZnxBig<&mut [u8], B> { + VecZnxBig { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } + + pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + // impl VecZnxBig { // pub fn print(&self, n: usize, col: usize) { // (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index bb46802..20b4f2e 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,6 +1,9 @@ use crate::ffi::vec_znx; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, DataView, FFT64, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; +use crate::{ + Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement, + bytes_of_vec_znx_big, +}; pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -113,9 +116,6 @@ pub trait VecZnxBigOps { /// Subtracts `res` from `a` and stores the result on `res`. fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; - /// Normalizes `a` and stores the result on `b`. /// /// # Arguments @@ -129,7 +129,7 @@ pub trait VecZnxBigOps { res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
@@ -146,6 +146,11 @@ pub trait VecZnxBigOps { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); } +pub trait VecZnxBigScratch { + /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; +} + impl VecZnxBigAlloc for Module { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { VecZnxBig::new(self, cols, size) @@ -160,7 +165,7 @@ impl VecZnxBigAlloc for Module { // } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBigOwned::bytes_of(self, cols, size) + bytes_of_vec_znx_big(self, cols, size) } } @@ -491,10 +496,6 @@ where } } - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - >::vec_znx_normalize_tmp_bytes(self) - } - fn vec_znx_big_normalize( &self, log_base2k: usize, @@ -502,7 +503,7 @@ where res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchSpace, + scratch: &mut ScratchBorr, ) { #[cfg(debug_assertions)] { @@ -513,6 +514,10 @@ where // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); // assert_alignement(tmp_bytes.as_ptr()); } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vec_znx_big_normalize_tmp_bytes( + &self, + )); unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, @@ -523,7 +528,7 @@ where a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - scratch.vec_znx_big_normalize_tmp_bytes(self).as_mut_ptr(), + tmp_bytes.as_mut_ptr(), ); } } @@ -574,3 +579,9 @@ where } } } + +impl VecZnxBigScratch for Module { + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { + ::vec_znx_normalize_tmp_bytes(self) + } +} diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 74b559c..5d15c00 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -54,13 +54,13 @@ impl> ZnxView for VecZnxDft { type Scalar = f64; } -impl>, B: Backend> VecZnxDft { - pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe 
{ vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } - } +pub(crate) fn bytes_of_vec_znx_dft(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +} +impl>, B: Backend> VecZnxDft { pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + let data = alloc_aligned::(bytes_of_vec_znx_dft(module, cols, size)); Self { data: data.into(), n: module.n(), @@ -72,7 +72,7 @@ impl>, B: Backend> VecZnxDft { pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, cols, size)); + assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size)); Self { data: data.into(), n: module.n(), @@ -85,8 +85,8 @@ impl>, B: Backend> VecZnxDft { pub type VecZnxDftOwned = VecZnxDft, B>; -impl<'a, D: ?Sized, B> VecZnxDft<&'a mut D, B> { - pub(crate) fn from_mut_slice(data: &'a mut D, n: usize, cols: usize, size: usize) -> Self { +impl VecZnxDft { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 2c1cc97..9a1db2a 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,6 +1,7 @@ -use crate::VecZnxDftOwned; use crate::ffi::{vec_znx_big, vec_znx_dft}; +use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::znx_base::ZnxInfos; +use crate::{Backend, VecZnxDftOwned}; use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; use std::cmp::min; @@ -66,12 +67,12 @@ pub trait VecZnxDftOps { fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); } -impl VecZnxDftAlloc for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { +impl 
VecZnxDftAlloc for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { VecZnxDftOwned::new(&self, cols, size) } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) } @@ -80,7 +81,7 @@ impl VecZnxDftAlloc for Module { // } fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDftOwned::bytes_of(&self, cols, size) + bytes_of_vec_znx_dft(self, cols, size) } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 6951651..d647860 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -43,9 +43,6 @@ pub trait VecZnxAlloc { } pub trait VecZnxOps { - /// Returns the minimum number of bytes necessary for normalization. - fn vec_znx_normalize_tmp_bytes(&self) -> usize; - /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. fn vec_znx_normalize( &self, @@ -137,6 +134,11 @@ pub trait VecZnxOps { fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize); } +pub trait VecZnxScratch { + /// Returns the minimum number of bytes necessary for normalization. + fn vec_znx_normalize_tmp_bytes(&self) -> usize; +} + impl VecZnxAlloc for Module { //(Jay)TODO: One must define the Scalar generic param here. 
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { @@ -157,10 +159,6 @@ where Data: AsRef<[u8]>, DataMut: AsRef<[u8]> + AsMut<[u8]>, { - fn vec_znx_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - fn vec_znx_normalize( &self, log_base2k: usize, @@ -174,7 +172,7 @@ where { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + assert!(tmp_bytes.len() >= ::vec_znx_normalize_tmp_bytes(&self)); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -489,3 +487,9 @@ where >::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col); } } + +impl VecZnxScratch for Module { + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } + } +} diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index a7361ad..69afef8 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -1,59 +1,6 @@ -use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut}; use itertools::izip; use std::cmp::min; -pub struct ZnxBase { - /// The ring degree - pub n: usize, - - /// The number of rows (in the third dimension) - pub rows: usize, - - /// The number of polynomials - pub cols: usize, - - /// The number of size per polynomial (a.k.a small polynomials). - pub size: usize, - - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. - pub data: Vec, - - /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). 
- pub ptr: *mut u8, -} - -impl ZnxBase { - pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { - let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); - res.data = bytes; - res - } - - pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_eq!(n & (n - 1), 0, "n must be a power of two"); - assert!(n > 0, "n must be greater than 0"); - assert!(rows > 0, "rows must be greater than 0"); - assert!(cols > 0, "cols must be greater than 0"); - assert!(size > 0, "size must be greater than 0"); - } - Self { - n: n, - rows: rows, - cols: cols, - size: size, - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - } - } -} - -pub trait GetZnxBase { - fn znx(&self) -> &ZnxBase; - fn znx_mut(&mut self) -> &mut ZnxBase; -} - pub trait ZnxInfos { /// Returns the ring degree of the polynomials. fn n(&self) -> usize; @@ -82,30 +29,6 @@ pub trait ZnxInfos { fn sl(&self) -> usize; } -// pub trait ZnxSliceSize {} - -//(Jay) TODO: Remove ZnxAlloc -// pub trait ZnxAlloc -// where -// Self: Sized + ZnxInfos, -// { -// type Scalar; -// fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { -// let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); -// Self::from_bytes(module, rows, cols, size, bytes) -// } - -// fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { -// let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); -// res.znx_mut().data = bytes; -// res -// } - -// fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - -// fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; -// } - pub trait DataView { type D; fn data(&self) -> &Self::D; @@ -176,35 +99,6 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { //(Jay)Note: Can't provide blanket impl. 
of ZnxView because Scalar is not known impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} -use std::convert::TryFrom; -use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; -pub trait Num: - Copy - + Default - + PartialEq - + PartialOrd - + Add - + Sub - + Mul - + Div - + Neg - + AddAssign -{ - const BITS: u32; -} - -impl Num for i64 { - const BITS: u32 = 64; -} - -impl Num for i128 { - const BITS: u32 = 128; -} - -impl Num for f64 { - const BITS: u32 = 64; -} - pub trait ZnxZero: ZnxViewMut where Self: Sized, @@ -261,128 +155,96 @@ pub fn switch_degree + ZnxZero, D: ZnxView }); } +use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; + +use crate::{ScratchBorr, cast_mut}; +pub trait Integer: + Copy + + Default + + PartialEq + + PartialOrd + + Add + + Sub + + Mul + + Div + + Neg + + Shl + + Shr + + AddAssign +{ + const BITS: u32; +} + +impl Integer for i64 { + const BITS: u32 = 64; +} + +impl Integer for i128 { + const BITS: u32 = 128; +} + // (Jay)TODO: implement rsh for VecZnx, VecZnxBig // pub trait ZnxRsh: ZnxZero { // fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { // rsh(k, log_base2k, self, col, carry) // } // } -// pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) { -// let n: usize = a.n(); -// let size: usize = a.size(); -// let cols: usize = a.cols(); +pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, scratch: &mut ScratchBorr) +where + V::Scalar: From + Integer, +{ + let n: usize = a.n(); + let size: usize = a.size(); + let cols: usize = a.cols(); -// #[cfg(debug_assertions)] -// { -// assert!( -// tmp_bytes.len() >= rsh_tmp_bytes::(n), -// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", -// tmp_bytes.len() / size_of::(), -// n, -// size, -// ); -// assert_alignement(tmp_bytes.as_ptr()); -// } + // #[cfg(debug_assertions)] + // { + // assert!( + // tmp_bytes.len() >= rsh_tmp_bytes::(n), + // "invalid carry: 
carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", + // tmp_bytes.len() / size_of::(), + // n, + // size, + // ); + // assert_alignement(tmp_bytes.as_ptr()); + // } -// let size: usize = a.size(); -// let steps: usize = k / log_base2k; + let size: usize = a.size(); + let steps: usize = k / log_base2k; -// a.raw_mut().rotate_right(n * steps * cols); -// (0..cols).for_each(|i| { -// (0..steps).for_each(|j| { -// a.zero_at(i, j); -// }) -// }); + a.raw_mut().rotate_right(n * steps * cols); + (0..cols).for_each(|i| { + (0..steps).for_each(|j| { + a.zero_at(i, j); + }) + }); -// let k_rem: usize = k % log_base2k; + let k_rem: usize = k % log_base2k; -// if k_rem != 0 { -// let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); + if k_rem != 0 { + let (carry, _) = scratch.tmp_scalar_slice::(rsh_tmp_bytes::(n)); -// unsafe { -// std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); -// } + unsafe { + std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); + } -// let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); -// let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); -// let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); + let log_base2k_t = V::Scalar::from(log_base2k); + let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem); + let k_rem_t = V::Scalar::from(k_rem); -// (steps..size).for_each(|i| { -// izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| { -// *xi += *ci << log_base2k_t; -// *ci = get_base_k_carry(*xi, shift); -// *xi = (*xi - *ci) >> k_rem_t; -// }); -// }) -// } -// } + (0..cols).for_each(|i| { + (steps..size).for_each(|j| { + izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| { + *xi += *ci << log_base2k_t; + *ci = (*xi << shift) >> shift; + *xi = (*xi - *ci) >> k_rem_t; + }); + }); + //TODO: ZERO CARRYcarry + }) + } +} -// #[inline(always)] -// fn get_base_k_carry(x: T, shift: T) -> T { -// (x << shift) >> shift 
-// } - -// pub fn rsh_tmp_bytes(n: usize) -> usize { -// n * std::mem::size_of::() -// } - -// pub trait ZnxLayout: ZnxInfos { -// type Scalar; - -// /// Returns true if the receiver is only borrowing the data. -// fn borrowing(&self) -> bool { -// self.znx().data.len() == 0 -// } - -// /// Returns a non-mutable pointer to the underlying coefficients array. -// fn as_ptr(&self) -> *const Self::Scalar { -// self.znx().ptr as *const Self::Scalar -// } - -// /// Returns a mutable pointer to the underlying coefficients array. -// fn as_mut_ptr(&mut self) -> *mut Self::Scalar { -// self.znx_mut().ptr as *mut Self::Scalar -// } - -// /// Returns a non-mutable reference to the entire underlying coefficient array. -// fn raw(&self) -> &[Self::Scalar] { -// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } -// } - -// /// Returns a mutable reference to the entire underlying coefficient array. -// fn raw_mut(&mut self) -> &mut [Self::Scalar] { -// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } -// } - -// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. -// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { -// #[cfg(debug_assertions)] -// { -// assert!(i < self.cols()); -// assert!(j < self.size()); -// } -// let offset: usize = self.n() * (j * self.cols() + i); -// unsafe { self.as_ptr().add(offset) } -// } - -// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. -// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { -// #[cfg(debug_assertions)] -// { -// assert!(i < self.cols()); -// assert!(j < self.size()); -// } -// let offset: usize = self.n() * (j * self.cols() + i); -// unsafe { self.as_mut_ptr().add(offset) } -// } - -// /// Returns non-mutable reference to the (i, j)-th small polynomial. 
-// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { -// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } -// } - -// /// Returns mutable reference to the (i, j)-th small polynomial. -// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { -// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } -// } -// } +pub fn rsh_tmp_bytes(n: usize) -> usize { + n * std::mem::size_of::() +} From bd105497fd097284f6e0fa7c029577eedb473014 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 4 May 2025 19:46:22 +0530 Subject: [PATCH 27/87] amend rlwe_encrypt example and minor changes at multiple places --- base2k/examples/rlwe_encrypt.rs | 52 ++++---- base2k/examples/vmp.rs | 122 +++++++++--------- base2k/src/lib.rs | 115 ++--------------- base2k/src/mat_znx_dft_ops.rs | 219 ++++++++++++++++---------------- base2k/src/vec_znx.rs | 15 ++- base2k/src/vec_znx_big.rs | 12 +- base2k/src/vec_znx_big_ops.rs | 9 +- base2k/src/vec_znx_dft.rs | 80 +++++------- base2k/src/vec_znx_dft_ops.rs | 13 -- base2k/src/vec_znx_ops.rs | 14 -- base2k/src/znx_base.rs | 30 +---- 11 files changed, 267 insertions(+), 414 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index afac2f8..742dcea 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, - VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, + Encoding, FFT64, Module, Sampling, ScalarAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps, + VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -13,24 +13,24 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut tmp_bytes_norm: Vec = 
alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); - let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size)); + let mut scratch = + ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar = module.new_scalar(1); + let mut s = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(s.cols()); + let mut s_dft = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) - let mut ct: VecZnx = module.new_vec_znx( + let mut ct = module.new_vec_znx( 2, // Number of columns ct_size, // Number of small poly per column ); @@ -39,11 +39,8 @@ fn main() { module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow( - 1, // Number of columns - ct.size(), // Number of polynomials per column - &mut tmp_bytes_dft, - ); + let scratch = scratch.borrow(); + let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size); // Applies DFT(ct[1]) * DFT(s) module.svp_apply_dft( @@ -56,13 +53,14 @@ fn main() { ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) - let mut buf_big: VecZnxBig = buf_dft.alias_as_vec_znx_big(); + let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + // Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data` a full qualified path seems necessary + as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: 
VecZnx with 1 column - let mut m: VecZnx = module.new_vec_znx( + let mut m = module.new_vec_znx( 1, // Number of columns msg_size, // Number of small polynomials ); @@ -70,10 +68,11 @@ fn main() { want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - m.normalize(log_base2k, 0, &mut tmp_bytes_norm); + let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::()); + m.normalize(log_base2k, 0, tmp_bytes_norm); // m - BIG(ct[1] * s) - module.vec_znx_big_sub_small_a_inplace( + module.vec_znx_big_sub_small_b_inplace( &mut buf_big, 0, // Selects the first column of the receiver &m, @@ -83,12 +82,9 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - log_base2k, - &mut ct, - 0, // Selects the first column of ct (ct[0]) - &buf_big, - 0, // Selects the first column of buf_big - &mut tmp_bytes_norm, + log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) + &buf_big, 0, // Selects the first column of buf_big + scratch, ); // Add noise to ct[0] @@ -118,14 +114,14 @@ fn main() { ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // m + e <- BIG(ct[1] * s + ct[0]) - let mut res: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm); + let mut res = module.new_vec_znx(1, ct_size); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; @@ -136,5 +132,7 @@ fn main() { .enumerate() .for_each(|(i, (a, b))| { println!("{}: {} {}", i, a, (*b as f64) / scale); - }) + }); + + module.free(); } diff --git a/base2k/examples/vmp.rs 
b/base2k/examples/vmp.rs index 710744e..36943f7 100644 --- a/base2k/examples/vmp.rs +++ b/base2k/examples/vmp.rs @@ -1,78 +1,78 @@ -use base2k::{ - Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - ZnxInfos, ZnxLayout, alloc_aligned, -}; +// use base2k::{ +// Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, +// ZnxInfos, ZnxLayout, alloc_aligned, +// }; fn main() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; + // let log_n: i32 = 5; + // let n: usize = 1 << log_n; - let module: Module = Module::::new(n); - let log_base2k: usize = 15; + // let module: Module = Module::::new(n); + // let log_base2k: usize = 15; - let a_cols: usize = 2; - let a_size: usize = 5; + // let a_cols: usize = 2; + // let a_size: usize = 5; - let log_k: usize = log_base2k * a_size - 5; + // let log_k: usize = log_base2k * a_size - 5; - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = 2; - let mat_size: usize = a_size + 1; + // let mat_rows: usize = a_size; + // let mat_cols_in: usize = a_cols; + // let mat_cols_out: usize = 2; + // let mat_size: usize = a_size + 1; - let mut tmp_bytes_vmp: Vec = alloc_aligned( - module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) - | module.vmp_apply_dft_tmp_bytes( - a_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ), - ); + // let mut tmp_bytes_vmp: Vec = alloc_aligned( + // module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + // | module.vmp_apply_dft_tmp_bytes( + // a_size, + // a_size, + // mat_rows, + // mat_cols_in, + // mat_cols_out, + // mat_size, + // ), + // ); - let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); + // let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); - let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); + // let 
mut a: VecZnx = module.new_vec_znx(a_cols, a_size); - (0..a_cols).for_each(|i| { - let mut values: Vec = vec![i64::default(); n]; - values[1 + i] = (1 << log_base2k) + 1; - a.encode_vec_i64(i, log_base2k, log_k, &values, 32); - a.normalize(log_base2k, i, &mut tmp_bytes_vmp); - a.print(n, i); - println!(); - }); + // (0..a_cols).for_each(|i| { + // let mut values: Vec = vec![i64::default(); n]; + // values[1 + i] = (1 << log_base2k) + 1; + // a.encode_vec_i64(i, log_base2k, log_k, &values, 32); + // a.normalize(log_base2k, i, &mut tmp_bytes_vmp); + // a.print(n, i); + // println!(); + // }); - let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + // let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - (0..a.size()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - (0..mat_cols_out).for_each(|j| { - tmp.at_mut(j, row_i)[1 + j] = 1 as i64; - }); - (0..mat_cols_in).for_each(|j| { - module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); - }) - }); + // (0..a.size()).for_each(|row_i| { + // let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); + // (0..mat_cols_out).for_each(|j| { + // tmp.at_mut(j, row_i)[1 + j] = 1 as i64; + // }); + // (0..mat_cols_in).for_each(|j| { + // module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); + // }) + // }); - let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); - module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); + // let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); + // module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); - let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); - let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); - 
(0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); + // let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); + // let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); + // (0..mat_cols_out).for_each(|i| { + // module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + // module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); - let mut values_res: Vec = vec![i64::default(); n]; - res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); - res.print(n, i); - println!(); - println!("{:?}", values_res); - println!(); - }); + // let mut values_res: Vec = vec![i64::default(); n]; + // res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); + // res.print(n, i); + // println!(); + // println!("{:?}", values_res); + // println!(); + // }); - module.free(); + // module.free(); } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f33ce60..38d6b4e 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -18,17 +18,10 @@ pub mod vec_znx_dft_ops; pub mod vec_znx_ops; pub mod znx_base; -use std::{ - any::type_name, - ops::{DerefMut, Sub}, -}; - pub use encoding::*; pub use mat_znx_dft::*; pub use mat_znx_dft_ops::*; pub use module::*; -use rand_core::le; -use rand_distr::num_traits::sign; pub use sampling::*; pub use scalar_znx::*; pub use scalar_znx_dft::*; @@ -133,6 +126,8 @@ pub fn alloc_aligned(size: usize) -> Vec { ) } +// Scratch implementation below + pub struct ScratchOwned(Vec); impl ScratchOwned { @@ -141,16 +136,16 @@ impl ScratchOwned { Self(data) } - pub fn borrow(&mut self) -> &mut ScratchBorr { - ScratchBorr::new(&mut self.0) + pub fn borrow(&mut self) -> &mut Scratch { + Scratch::new(&mut self.0) } } -pub struct ScratchBorr { +pub struct Scratch { data: [u8], } -impl ScratchBorr { +impl Scratch { fn new(data: &mut [u8]) -> &mut Self { unsafe { &mut *(data as *mut [u8] as *mut 
Self) } } @@ -175,14 +170,14 @@ impl ScratchBorr { panic!( "Attempted to take {} from scratch with {} aligned bytes left", take_len, - take_len, + aligned_len, // type_name::(), // aligned_len ); } } - fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + pub fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); unsafe { @@ -193,7 +188,7 @@ impl ScratchBorr { } } - fn tmp_vec_znx_dft( + pub fn tmp_vec_znx_dft( &mut self, module: &Module, cols: usize, @@ -207,103 +202,17 @@ impl ScratchBorr { ) } - fn tmp_vec_znx_big From<&'a mut [u8]>, B: Backend>( + pub fn tmp_vec_znx_big( &mut self, module: &Module, cols: usize, size: usize, - ) -> (VecZnxBig, &mut Self) { + ) -> (VecZnxBig<&mut [u8], B>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size)); ( - VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size), + VecZnxBig::from_data(take_slice, module.n(), cols, size), Self::new(rem_slice), ) } } - -// pub struct ScratchBorrowed<'a> { -// data: &'a mut [u8], -// } - -// impl<'a> ScratchBorrowed<'a> { -// fn take_slice(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) { -// let ptr = self.data.as_mut_ptr(); -// let self_len = self.data.len(); - -// //TODO(Jay): print the offset sometimes, just to check -// let aligned_offset = ptr.align_offset(DEFAULTALIGN); -// let aligned_len = self_len.saturating_sub(aligned_offset); - -// let take_len_bytes = take_len * std::mem::size_of::(); - -// if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) { -// unsafe { -// let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes); -// let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); - -// let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes); - -// return (take_slice, 
ScratchBorrowed { data: rem_slice }); -// } -// } else { -// panic!( -// "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left", -// take_len_bytes, -// take_len, -// type_name::(), -// aligned_len -// ); -// } -// } - -// fn reborrow(&mut self) -> ScratchBorrowed<'a> { -// //(Jay)TODO: `data: &mut *self.data` does not work because liftime of &mut self is different from 'a. -// // But it feels that there should be a simpler impl. than the one below -// Self { -// data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) }, -// } -// } - -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) { -// let (data, re_scratch) = self.take_slice::(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); -// ( -// VecZnxDft::from_data(data, module.n(), cols, size), -// re_scratch, -// ) -// } - -// pub(crate) fn len(&self) -> usize { -// self.data.len() -// } -// } - -// pub trait Scratch { -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (D, &mut Self); -// } - -// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr { -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) { -// let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); -// ( -// data -// rem_scratch, -// ) -// } - -// // fn tmp_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) { -// // // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size)); -// // // ( -// // // VecZnxBig::from_data(data, module.n(), cols, size), -// // // re_scratch, -// // // ) -// // } - -// // fn scalar_slice(&mut self, len: usize) -> (&mut [T], Self) { -// // self.take_slice::(len) -// // } - -// // fn reborrow(&mut self) -> Self { -// // self.reborrow() -// // } -// } diff --git 
a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 5ab44df..658ff5d 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, }; @@ -25,15 +25,6 @@ pub trait MatZnxDftAlloc { size: usize, bytes: Vec, ) -> MatZnxDftAllocOwned; - - // fn new_mat_znx_dft_from_bytes_borrow( - // &self, - // rows: usize, - // cols_in: usize, - // cols_out: usize, - // size: usize, - // bytes: &mut [u8], - // ) -> MatZnxDft; } pub trait MatZnxDftScratch { @@ -101,7 +92,7 @@ pub trait MatZnxDftOps { b_row: usize, b_col_in: usize, a: &VecZnx, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. @@ -118,7 +109,7 @@ pub trait MatZnxDftOps { a: &MatZnxDft, b_row: usize, b_col_in: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. @@ -165,7 +156,7 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchBorr); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut Scratch); /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
@@ -197,7 +188,7 @@ pub trait MatZnxDftOps { c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); } @@ -274,18 +265,14 @@ impl MatZnxDftScratch for Module { } } -impl MatZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>, - Data: AsRef<[u8]>, -{ +impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { fn vmp_prepare_row( &self, - b: &mut MatZnxDft, + b: &mut MatZnxDft<&mut [u8], FFT64>, b_row: usize, b_col_in: usize, - a: &VecZnx, - scratch: &mut ScratchBorr, + a: &VecZnx<&[u8]>, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -328,21 +315,19 @@ where let a_size: usize = a.size(); // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - let (mut a_dft, _) = scratch.tmp_scalar_slice(12); - DataMut::from(a_dft); - // let (mut a_dft, _) = scratch.tmp_vec_znx_dft::(self, cols_out, a_size); + let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size); (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); - Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); + Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref()); } fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx<&mut [u8]>, + a: &MatZnxDft<&[u8], FFT64>, a_row: usize, a_col_in: usize, - mut scratch: &mut ScratchBorr, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -386,12 +371,18 @@ where Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size); (0..cols_out).for_each(|i| { - >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); + >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); }); } - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { + fn 
vmp_prepare_row_dft( + &self, + b: &mut MatZnxDft<&mut [u8], FFT64>, + b_row: usize, + b_col_in: usize, + a: &VecZnxDft<&[u8], FFT64>, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -436,7 +427,13 @@ where } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { + fn vmp_extract_row_dft( + &self, + b: &mut VecZnxDft<&mut [u8], FFT64>, + a: &MatZnxDft<&[u8], FFT64>, + a_row: usize, + a_col_in: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -482,10 +479,10 @@ where fn vmp_apply_dft( &self, - c: &mut VecZnxDft, - a: &VecZnx, - b: &MatZnxDft, - mut scratch: &mut ScratchBorr, + c: &mut VecZnxDft<&mut [u8], FFT64>, + a: &VecZnx<&[u8]>, + b: &MatZnxDft<&[u8], FFT64>, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -547,68 +544,70 @@ where fn vmp_apply_dft_to_dft( &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - mut scratch: &mut ScratchBorr, + c: &mut VecZnxDft<&mut [u8], FFT64>, + a: &VecZnxDft<&[u8], FFT64>, + b: &MatZnxDft<&[u8], FFT64>, + scratch: &mut Scratch, ) { - #[cfg(debug_assertions)] { - assert_eq!(c.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - c.cols(), - b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_to_dft_tmp_bytes( - // c.cols(), - // c.size(), - // a.cols(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // assert_alignement(tmp_bytes.as_ptr()); - } + #[cfg(debug_assertions)] + { + assert_eq!(c.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + c.cols(), + b.cols_out(), + "c.cols(): {} != b.cols_out: {}", + c.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != 
b.cols_in: {}", + a.cols(), + b.cols_in() + ); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_to_dft_tmp_bytes( + // c.cols(), + // c.size(), + // a.cols(), + // a.size(), + // b.rows(), + // b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); + } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.poly_count() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.poly_count() as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ) + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( + c.cols(), + c.size(), + a.cols(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.poly_count() as u64, + a.as_ptr() as *const vec_znx_dft_t, + a.poly_count() as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + (b.size() * b.cols()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } } } } @@ -658,27 +657,31 @@ mod tests { module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - // let g = vmpmat_0.to_mut(); - - module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow()); + module.vmp_prepare_row( + &mut vmpmat_0.to_mut(), + row_i, + col_in, + &a.to_ref(), + scratch.borrow(), + ); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); + module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref()); assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - 
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in); + module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - // module.vmp_extract_row( - // log_base2k, - // &mut b.to_mut(), - // &vmpmat_0.to_ref(), - // row_i, - // col_in, - // scratch.borrow(), - // ); + module.vmp_extract_row( + log_base2k, + &mut b.to_mut(), + &vmpmat_0.to_ref(), + row_i, + col_in, + scratch.borrow(), + ); (0..mat_cols_out).for_each(|col_out| { module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b386604..09b0051 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -217,7 +217,7 @@ pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; impl VecZnx> { - pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, @@ -226,7 +226,7 @@ impl VecZnx> { } } - pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> { + pub fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, @@ -235,3 +235,14 @@ impl VecZnx> { } } } + +impl VecZnx<&mut [u8]> { + pub fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: &self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 7442f11..fe67516 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -3,7 +3,7 @@ use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; -const VEC_ZNX_BIG_ROWS: usize = 1; +// const VEC_ZNX_BIG_ROWS: usize = 1; /// VecZnxBig is `Backend` dependent, denoted with backend generic `B` pub struct VecZnxBig { @@ -97,7 +97,7 @@ impl VecZnxBig { pub type 
VecZnxBigOwned = VecZnxBig, B>; impl VecZnxBig, B> { - pub(crate) fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut_slice(), n: self.n, @@ -107,7 +107,7 @@ impl VecZnxBig, B> { } } - pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> { + pub fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_slice(), n: self.n, @@ -117,9 +117,3 @@ impl VecZnxBig, B> { } } } - -// impl VecZnxBig { -// pub fn print(&self, n: usize, col: usize) { -// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); -// } -// } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 20b4f2e..d0e4bd3 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,9 +1,6 @@ use crate::ffi::vec_znx; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement, - bytes_of_vec_znx_big, -}; +use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big}; pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -129,7 +126,7 @@ pub trait VecZnxBigOps { res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
@@ -503,7 +500,7 @@ where res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 5d15c00..a4a3242 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -4,7 +4,7 @@ use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; -const VEC_ZNX_DFT_ROWS: usize = 1; +// const VEC_ZNX_DFT_ROWS: usize = 1; // VecZnxDft is `Backend` dependent denoted with generic `B` pub struct VecZnxDft { @@ -97,52 +97,36 @@ impl VecZnxDft { } } -// impl ZnxAlloc for VecZnxDft { -// type Scalar = u8; +impl VecZnxDft, B> { + pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } -// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { -// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); -// Self { -// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), -// _marker: PhantomData, -// } -// } + pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} -// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { -// debug_assert_eq!( -// _rows, VEC_ZNX_DFT_ROWS, -// "rows != {} not supported for VecZnxDft", -// VEC_ZNX_DFT_ROWS -// ); -// unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } -// } -// } - -// impl VecZnxDft { -// pub fn print(&self, n: usize, col: usize) { -// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); -// } -// } - -// impl VecZnxDft { -// /// Cast a [VecZnxDft] into a [VecZnxBig]. 
-// /// The returned [VecZnxBig] shares the backing array -// /// with the original [VecZnxDft]. -// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { -// assert!( -// self.data().len() == 0, -// "cannot alias VecZnxDft into VecZnxBig if it owns the data" -// ); -// VecZnxBig:: { -// inner: ZnxBase { -// data: Vec::new(), -// ptr: self.ptr(), -// n: self.n(), -// rows: self.rows(), -// cols: self.cols(), -// size: self.size(), -// }, -// _marker: PhantomData, -// } -// } -// } +impl VecZnxDft<&mut [u8], B> { + pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: &self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 9a1db2a..e894ef4 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -22,19 +22,6 @@ pub trait VecZnxDftAlloc { /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; - // /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - // /// - // /// Behavior: the backing array is only borrowed. - // /// - // /// # Arguments - // /// - // /// * `cols`: the number of cols of the [VecZnxDft]. - // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - // /// - // /// # Panics - // /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// /// # Arguments diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index d647860..a8edb12 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -23,19 +23,6 @@ pub trait VecZnxAlloc { /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. 
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; - // /// Instantiates a new [VecZnx] from a slice of bytes. - // /// The returned [VecZnx] does take ownership of the slice of bytes. - // /// - // /// # Arguments - // /// - // /// * `cols`: the number of polynomials. - // /// * `size`: the number small polynomials per column. - // /// - // /// # Panic - // /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - // fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; - // (Jay)TODO - /// Returns the number of bytes necessary to allocate /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. @@ -140,7 +127,6 @@ pub trait VecZnxScratch { } impl VecZnxAlloc for Module { - //(Jay)TODO: One must define the Scalar generic param here. fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { VecZnxOwned::new::(self.n(), cols, size) } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 69afef8..9eea5bb 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -1,4 +1,5 @@ use itertools::izip; +use rand_distr::num_traits::Zero; use std::cmp::min; pub trait ZnxInfos { @@ -157,7 +158,7 @@ pub fn switch_degree + ZnxZero, D: ZnxView use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; -use crate::{ScratchBorr, cast_mut}; +use crate::Scratch; pub trait Integer: Copy + Default @@ -183,32 +184,15 @@ impl Integer for i128 { const BITS: u32 = 128; } -// (Jay)TODO: implement rsh for VecZnx, VecZnxBig -// pub trait ZnxRsh: ZnxZero { -// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { -// rsh(k, log_base2k, self, col, carry) -// } -// } -pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, scratch: &mut ScratchBorr) +//(Jay)Note: `rsh` impl. 
ignores the column +pub fn rsh(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch) where - V::Scalar: From + Integer, + V::Scalar: From + Integer + Zero, { let n: usize = a.n(); - let size: usize = a.size(); + let _size: usize = a.size(); let cols: usize = a.cols(); - // #[cfg(debug_assertions)] - // { - // assert!( - // tmp_bytes.len() >= rsh_tmp_bytes::(n), - // "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", - // tmp_bytes.len() / size_of::(), - // n, - // size, - // ); - // assert_alignement(tmp_bytes.as_ptr()); - // } - let size: usize = a.size(); let steps: usize = k / log_base2k; @@ -240,7 +224,7 @@ where *xi = (*xi - *ci) >> k_rem_t; }); }); - //TODO: ZERO CARRYcarry + carry.iter_mut().for_each(|r| *r = V::Scalar::zero()); }) } } From ffa363804b062cc0cceffd4eadbb18950a8b75bd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 5 May 2025 17:35:35 +0200 Subject: [PATCH 28/87] rework as discussed --- base2k/examples/rlwe_encrypt.rs | 43 +- base2k/examples/vmp.rs | 78 ---- base2k/src/lib.rs | 8 + base2k/src/mat_znx_dft.rs | 104 +++-- base2k/src/mat_znx_dft_ops.rs | 649 +++++++++---------------------- base2k/src/sampling.rs | 63 ++- base2k/src/scalar_znx.rs | 100 +++-- base2k/src/scalar_znx_dft.rs | 92 +++-- base2k/src/scalar_znx_dft_ops.rs | 70 ++-- base2k/src/vec_znx.rs | 77 +++- base2k/src/vec_znx_big.rs | 71 +++- base2k/src/vec_znx_big_ops.rs | 363 +++++++++-------- base2k/src/vec_znx_dft.rs | 65 +++- base2k/src/vec_znx_dft_ops.rs | 123 +++--- base2k/src/vec_znx_ops.rs | 371 ++++++++++++------ base2k/src/znx_base.rs | 30 +- 16 files changed, 1154 insertions(+), 1153 deletions(-) delete mode 100644 base2k/examples/vmp.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 742dcea..b55efba 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,7 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, ScalarAlloc, 
ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps, - VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, + Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, + VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -13,24 +14,23 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut scratch = - ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s = module.new_scalar(1); + let mut s: Scalar> = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft = module.new_scalar_znx_dft(s.cols()); + let mut s_dft: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) - let mut ct = module.new_vec_znx( + let mut ct: VecZnx> = module.new_vec_znx( 2, // Number of columns ct_size, // Number of small poly per column ); @@ -38,12 +38,10 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); - // Scratch space for DFT values - let scratch = scratch.borrow(); - let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size); + let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); // Applies DFT(ct[1]) * DFT(s) - module.svp_apply_dft( + module.svp_apply( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res 
&s_dft, // DFT(s) @@ -53,11 +51,10 @@ fn main() { ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) - let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - // Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data` a full qualified path seems necessary - as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); + let mut buf_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_size); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column let mut m = module.new_vec_znx( @@ -68,8 +65,7 @@ fn main() { want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::()); - m.normalize(log_base2k, 0, tmp_bytes_norm); + module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow()); // m - BIG(ct[1] * s) module.vec_znx_big_sub_small_b_inplace( @@ -82,9 +78,12 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) - &buf_big, 0, // Selects the first column of buf_big - scratch, + log_base2k, + &mut ct, + 0, // Selects the first column of ct (ct[0]) + &buf_big, + 0, // Selects the first column of buf_big + scratch.borrow(), ); // Add noise to ct[0] @@ -104,7 +103,7 @@ fn main() { // Decryption // DFT(ct[1] * s) - module.svp_apply_dft( + module.svp_apply( &mut buf_dft, 0, // Selects the first column of res. 
&s_dft, @@ -114,14 +113,14 @@ fn main() { ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // m + e <- BIG(ct[1] * s + ct[0]) let mut res = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow()); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; diff --git a/base2k/examples/vmp.rs b/base2k/examples/vmp.rs deleted file mode 100644 index 36943f7..0000000 --- a/base2k/examples/vmp.rs +++ /dev/null @@ -1,78 +0,0 @@ -// use base2k::{ -// Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, -// ZnxInfos, ZnxLayout, alloc_aligned, -// }; - -fn main() { - // let log_n: i32 = 5; - // let n: usize = 1 << log_n; - - // let module: Module = Module::::new(n); - // let log_base2k: usize = 15; - - // let a_cols: usize = 2; - // let a_size: usize = 5; - - // let log_k: usize = log_base2k * a_size - 5; - - // let mat_rows: usize = a_size; - // let mat_cols_in: usize = a_cols; - // let mat_cols_out: usize = 2; - // let mat_size: usize = a_size + 1; - - // let mut tmp_bytes_vmp: Vec = alloc_aligned( - // module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) - // | module.vmp_apply_dft_tmp_bytes( - // a_size, - // a_size, - // mat_rows, - // mat_cols_in, - // mat_cols_out, - // mat_size, - // ), - // ); - - // let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); - - // let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); - - // (0..a_cols).for_each(|i| { - // let mut values: Vec = vec![i64::default(); n]; - // values[1 + i] = (1 << log_base2k) + 1; - // a.encode_vec_i64(i, 
log_base2k, log_k, &values, 32); - // a.normalize(log_base2k, i, &mut tmp_bytes_vmp); - // a.print(n, i); - // println!(); - // }); - - // let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - // (0..a.size()).for_each(|row_i| { - // let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - // (0..mat_cols_out).for_each(|j| { - // tmp.at_mut(j, row_i)[1 + j] = 1 as i64; - // }); - // (0..mat_cols_in).for_each(|j| { - // module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); - // }) - // }); - - // let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); - // module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); - - // let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); - // let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); - // (0..mat_cols_out).for_each(|i| { - // module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - // module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); - - // let mut values_res: Vec = vec![i64::default(); n]; - // res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); - // res.print(n, i); - // println!(); - // println!("{:?}", values_res); - // println!(); - // }); - - // module.free(); -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 38d6b4e..f3b2525 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -215,4 +215,12 @@ impl Scratch { Self::new(rem_slice), ) } + + pub fn tmp_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size)); + ( + VecZnx::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 7a39dd1..1f18b48 100644 --- a/base2k/src/mat_znx_dft.rs +++ 
b/base2k/src/mat_znx_dft.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -8,17 +8,17 @@ use std::marker::PhantomData; /// /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. /// See the trait [MatZnxDftOps] for additional information. -pub struct MatZnxDft { +pub struct MatZnxDft { data: D, n: usize, size: usize, rows: usize, cols_in: usize, cols_out: usize, - _marker: PhantomData, + _phantom: PhantomData, } -impl ZnxInfos for MatZnxDft { +impl ZnxInfos for MatZnxDft { fn cols(&self) -> usize { self.cols_in } @@ -34,20 +34,22 @@ impl ZnxInfos for MatZnxDft { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for MatZnxDft { fn sl(&self) -> usize { - self.n() + self.n() * self.cols_out() } } -impl DataView for MatZnxDft { +impl DataView for MatZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for MatZnxDft { +impl DataViewMut for MatZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -57,7 +59,7 @@ impl> ZnxView for MatZnxDft { type Scalar = f64; } -impl MatZnxDft { +impl MatZnxDft { pub(crate) fn cols_in(&self) -> usize { self.cols_in } @@ -87,7 +89,7 @@ impl>, B: Backend> MatZnxDft { rows, cols_in, cols_out, - _marker: PhantomData, + _phantom: PhantomData, } } @@ -108,7 +110,7 @@ impl>, B: Backend> MatZnxDft { rows, cols_in, cols_out, - _marker: PhantomData, + _phantom: PhantomData, } } } @@ -151,28 +153,80 @@ impl> MatZnxDft { pub type MatZnxDftAllocOwned = MatZnxDft, B>; -impl MatZnxDft, B> { - pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { +pub trait MatZnxDftToRef { + fn to_ref(&self) -> MatZnxDft<&[u8], B>; +} + +pub trait MatZnxDftToMut { + fn to_mut(&mut 
self) -> MatZnxDft<&mut [u8], B>; +} + +impl MatZnxDftToMut for MatZnxDft, B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { MatZnxDft { data: self.data.as_mut_slice(), n: self.n, - size: self.size, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, - _marker: PhantomData, - } - } - - pub fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_slice(), - n: self.n, size: self.size, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - _marker: PhantomData, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft, B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&[u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, } } } diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 658ff5d..9b79a2c 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,11 +2,11 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, 
MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, + VecZnxDftToRef, }; -pub trait MatZnxDftAlloc { +pub trait MatZnxDftAlloc { /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments @@ -28,43 +28,10 @@ pub trait MatZnxDftAlloc { } pub trait MatZnxDftScratch { - /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] - fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; - - /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; - - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes( - &self, - c_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize; - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. 
- fn vmp_apply_dft_to_dft_tmp_bytes( + fn vmp_apply_tmp_bytes( &self, - c_cols: usize, - c_size: usize, - a_cols: usize, + res_size: usize, a_size: usize, b_rows: usize, b_cols_in: usize, @@ -75,43 +42,7 @@ pub trait MatZnxDftScratch { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { - /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the row of the [MatZnxDft] to prepare. - /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft]. - /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row( - &self, - b: &mut MatZnxDft, - b_row: usize, - b_col_in: usize, - a: &VecZnx, - scratch: &mut Scratch, - ); - - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. - /// * `a`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row( - &self, - log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, - b_row: usize, - b_col_in: usize, - scratch: &mut Scratch, - ); - +pub trait MatZnxDftOps { /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// /// # Arguments @@ -121,7 +52,10 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); + fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + where + R: MatZnxDftToMut, + A: VecZnxDftToRef; /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. 
/// @@ -130,33 +64,10 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut Scratch); + fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + where + R: VecZnxDftToMut, + A: MatZnxDftToRef; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -183,13 +94,11 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. 
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft( - &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - scratch: &mut Scratch, - ); + fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef; } impl MatZnxDftAlloc for Module { @@ -213,40 +122,10 @@ impl MatZnxDftAlloc for Module { } } -impl MatZnxDftScratch for Module { - fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - >::bytes_of_vec_znx_dft(self, cols_out, size) - } - - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - >::bytes_of_vec_znx_dft(self, cols_out, size) - + ::vec_znx_big_normalize_tmp_bytes(self) - } - - fn vmp_apply_dft_tmp_bytes( +impl MatZnxDftScratch for Module { + fn vmp_apply_tmp_bytes( &self, - c_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize { - unsafe { - vmp::vmp_apply_dft_tmp_bytes( - self.ptr, - c_size as u64, - a_size as u64, - (b_rows * b_cols_in) as u64, - (b_size * b_cols_out) as u64, - ) as usize - } - } - fn vmp_apply_dft_to_dft_tmp_bytes( - &self, - c_cols: usize, - c_size: usize, - a_cols: usize, + res_size: usize, a_size: usize, b_rows: usize, b_cols_in: usize, @@ -256,8 +135,8 @@ impl MatZnxDftScratch for Module { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.ptr, - (c_size * c_cols) as u64, - (a_size * a_cols) as u64, + (res_size * b_cols_out) as u64, + (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, (b_size * b_cols_out) as u64, ) as usize @@ -265,152 +144,43 @@ impl MatZnxDftScratch for Module { } } -impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { - fn vmp_prepare_row( - &self, - b: &mut MatZnxDft<&mut [u8], FFT64>, - b_row: usize, - b_col_in: usize, - a: &VecZnx<&[u8]>, - scratch: &mut Scratch, - ) { +impl MatZnxDftOps for Module { + fn 
vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + where + R: MatZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res: MatZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( a.cols(), - b.cols_out(), - "a.cols(): {} != b.cols_out(): {}", + res.cols_out(), + "a.cols(): {} != res.cols_out(): {}", a.cols(), - b.cols_out() + res.cols_out() ); assert!( - b_row < b.rows(), - "b_row: {} >= b.rows(): {}", - b_row, - b.rows() + res_row < res.rows(), + "res_row: {} >= res.rows(): {}", + res_row, + res.rows() ); assert!( - b_col_in < b.cols_in(), - "b_col_in: {} >= b.cols_in(): {}", - b_col_in, - b.cols_in() + res_col_in < res.cols_in(), + "res_col_in: {} >= res.cols_in(): {}", + res_col_in, + res.cols_in() ); assert_eq!( - b.size(), + res.size(), a.size(), - "b.size(): {} != a.size(): {}", - b.size(), - a.size() - ); - // assert!( - // tmp_bytes.len() - // >= >::vmp_prepare_row_tmp_bytes(self, a.cols(), a.size()) - // ); - // assert!(is_aligned(tmp_bytes.as_ptr())) - } - - let cols_out: usize = a.cols(); - let a_size: usize = a.size(); - - // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size); - (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); - Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref()); - } - - fn vmp_extract_row( - &self, - log_base2k: usize, - b: &mut VecZnx<&mut [u8]>, - a: &MatZnxDft<&[u8], FFT64>, - a_row: usize, - a_col_in: usize, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - b.cols(), - a.cols_out(), - "b.cols(): {} != a.cols_out(): {}", - b.cols(), - a.cols_out() - ); - assert!( - a_row < a.rows(), - "a_row: {} >= 
a.rows(): {}", - a_row, - a.rows() - ); - assert!( - a_col_in < a.cols_in(), - "a_col_in: {} >= a.cols_in(): {}", - a_col_in, - a.cols_in() - ); - assert_eq!( - b.size(), - a.size(), - "b.size(): {} != a.size(): {}", - b.size(), - a.size() - ); - // assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); - // assert!(is_aligned(tmp_bytes.as_ptr())) - } - - let cols_out: usize = b.cols(); - let size: usize = b.size(); - - // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); - let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size); - Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); - let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size); - (0..cols_out).for_each(|i| { - >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); - self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); - }); - } - - fn vmp_prepare_row_dft( - &self, - b: &mut MatZnxDft<&mut [u8], FFT64>, - b_row: usize, - b_col_in: usize, - a: &VecZnxDft<&[u8], FFT64>, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - a.cols(), - b.cols_out(), - "a.cols(): {} != b.cols_out(): {}", - a.cols(), - b.cols_out() - ); - assert!( - b_row < b.rows(), - "b_row: {} >= b.rows(): {}", - b_row, - b.rows() - ); - assert!( - b_col_in < b.cols_in(), - "b_col_in: {} >= b.cols_in(): {}", - b_col_in, - b.cols_in() - ); - assert_eq!( - b.size(), - a.size(), - "b.size(): {} != a.size(): {}", - b.size(), + "res.size(): {} != a.size(): {}", + res.size(), a.size() ); } @@ -418,31 +188,32 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { unsafe { vmp::vmp_prepare_row_dft( self.ptr, - b.as_mut_ptr() as *mut vmp::vmp_pmat_t, + res.as_mut_ptr() as *mut vmp::vmp_pmat_t, a.as_ptr() as *const vec_znx_dft_t, - (b_row * b.cols_in() + b_col_in) as u64, - (b.rows() * b.cols_in()) as u64, - (b.size() * b.cols_out()) as u64, + 
(res_row * res.cols_in() + res_col_in) as u64, + (res.rows() * res.cols_in()) as u64, + (res.size() * res.cols_out()) as u64, ); } } - fn vmp_extract_row_dft( - &self, - b: &mut VecZnxDft<&mut [u8], FFT64>, - a: &MatZnxDft<&[u8], FFT64>, - a_row: usize, - a_col_in: usize, - ) { + fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + where + R: VecZnxDftToMut, + A: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: MatZnxDft<&[u8], _> = a.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( - b.cols(), + res.cols(), a.cols_out(), - "b.cols(): {} != a.cols_out(): {}", - b.cols(), + "res.cols(): {} != a.cols_out(): {}", + res.cols(), a.cols_out() ); assert!( @@ -458,17 +229,17 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { a.cols_in() ); assert_eq!( - b.size(), + res.size(), a.size(), - "b.size(): {} != a.size(): {}", - b.size(), + "res.size(): {} != a.size(): {}", + res.size(), a.size() ); } unsafe { vmp::vmp_extract_row_dft( self.ptr, - b.as_mut_ptr() as *mut vec_znx_dft_t, + res.as_mut_ptr() as *mut vec_znx_dft_t, a.as_ptr() as *const vmp::vmp_pmat_t, (a_row * a.cols_in() + a_col_in) as u64, (a.rows() * a.cols_in()) as u64, @@ -477,23 +248,26 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { } } - fn vmp_apply_dft( - &self, - c: &mut VecZnxDft<&mut [u8], FFT64>, - a: &VecZnx<&[u8]>, - b: &MatZnxDft<&[u8], FFT64>, - scratch: &mut Scratch, - ) { + fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = b.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(c.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(b.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( - c.cols(), 
+ res.cols(), b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), b.cols_out() ); assert_eq!( @@ -503,37 +277,23 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { a.cols(), b.cols_in() ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_tmp_bytes( - // c.size(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // assert_alignement(tmp_bytes.as_ptr()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vmp_apply_dft_tmp_bytes( - self, - c.size(), + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes( + res.size(), a.size(), b.rows(), b.cols_in(), b.cols_out(), b.size(), )); - unsafe { - vmp::vmp_apply_dft( + vmp::vmp_apply_dft_to_dft( self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - (c.size() * c.cols()) as u64, - a.as_ptr(), + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, (a.size() * a.cols()) as u64, - a.n() as u64, b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, @@ -541,164 +301,131 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { ) } } - - fn vmp_apply_dft_to_dft( - &self, - c: &mut VecZnxDft<&mut [u8], FFT64>, - a: &VecZnxDft<&[u8], FFT64>, - b: &MatZnxDft<&[u8], FFT64>, - scratch: &mut Scratch, - ) { - { - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - c.cols(), - b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_to_dft_tmp_bytes( - // c.cols(), - // c.size(), - // a.cols(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // 
assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.poly_count() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.poly_count() as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - } } - #[cfg(test)] mod tests { - use crate::ScratchOwned; - use crate::mat_znx_dft_ops::*; - use crate::vec_znx_big_ops::*; - use crate::vec_znx_dft_ops::*; - use crate::vec_znx_ops::*; use crate::{ - FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned, + Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, }; use sampling::source::Source; + use super::{MatZnxDftAlloc, MatZnxDftScratch}; + #[test] - fn vmp_prepare_row_dft() { + fn vmp_prepare_row() { let module: Module = Module::::new(16); let log_base2k: usize = 8; let mat_rows: usize = 4; let mat_cols_in: usize = 2; let mat_cols_out: usize = 2; let mat_size: usize = 5; - let mut a: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); - let mut b: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); - let mut a_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut a_big: VecZnxBig<_, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); - let mut b_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut vmpmat_0: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut vmpmat_1: MatZnxDft<_, FFT64> = 
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - // let mut tmp_bytes: Vec = - // alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); - let mut scratch = ScratchOwned::new( - 2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()), - ); - let mut tmp_bytes: Vec = - alloc_aligned::( as VecZnxDftOps, Vec, _>>::vec_znx_idft_tmp_bytes(&module)); + let mut a: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut mat: MatZnxDft, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); for col_in in 0..mat_cols_in { for row_i in 0..mat_rows { let mut source: Source = Source::new([0u8; 32]); - (0..mat_cols_out).for_each(|col_out| { module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source); module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - - module.vmp_prepare_row( - &mut vmpmat_0.to_mut(), - row_i, - col_in, - &a.to_ref(), - scratch.borrow(), - ); - - // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref()); - assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); - - // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in); + module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); + module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); - - // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row( - log_base2k, - &mut b.to_mut(), - &vmpmat_0.to_ref(), - row_i, - col_in, - scratch.borrow(), - ); - - (0..mat_cols_out).for_each(|col_out| { - 
module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); - module.vec_znx_big_normalize( - log_base2k, - &mut a.to_mut(), - col_out, - &a_big.to_ref(), - col_out, - scratch.borrow(), - ); - }); - - assert_eq!(a.raw(), b.raw()); } } module.free(); } + + #[test] + fn vmp_apply() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let log_base2k: usize = 15; + let a_size: usize = 5; + let mat_size: usize = 6; + let res_size: usize = 5; + + [1, 2].iter().for_each(|in_cols| { + [1, 2].iter().for_each(|out_cols| { + let a_cols: usize = *in_cols; + let res_cols: usize = *out_cols; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + let res_cols: usize = mat_cols_out; + + let mut scratch: ScratchOwned = ScratchOwned::new( + module.vmp_apply_tmp_bytes( + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(), + ); + + let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|i| { + a.at_mut(i, 2)[i + 1] = 1; + }); + + let mut mat_znx_dft: MatZnxDft, FFT64> = + module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + + let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + + // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. 
+ (0..a.size()).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx = 1 + col_in_i * mat_cols_out + col_out_i; + tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} + module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; + }); + module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + }); + }); + + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); + (0..a_cols).for_each(|i| { + module.vec_znx_dft(&mut a_dft, i, &a, i); + }); + + module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); + + let mut res_have_vi64: Vec = vec![i64::default(); n]; + + let mut res_have: VecZnx> = module.new_vec_znx(res_cols, res_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + (0..mat_cols_out).for_each(|col_i| { + let mut res_want_vi64: Vec = vec![i64::default(); n]; + (0..a_cols).for_each(|i| { + res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; + }); + res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64); + assert_eq!(res_have_vi64, res_want_vi64); + }); + }); + }); + + module.free(); + } } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index a8b1962..b254286 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,53 +1,47 @@ use crate::znx_base::ZnxViewMut; -use crate::{Backend, Module, VecZnx}; +use crate::{Backend, Module, VecZnx, VecZnxToMut}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - size: usize, - source: &mut Source, - ); + fn fill_uniform(&self, log_base2k: 
usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) + where + A: VecZnxToMut; /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + fn add_dist_f64>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ); + ) where + A: VecZnxToMut; /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal + AsRef<[u8]>>( + fn add_normal( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64, - ); + ) where + A: VecZnxToMut; } impl Sampling for Module { - fn fill_uniform + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - size: usize, - source: &mut Source, - ) { + fn fill_uniform(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; @@ -58,16 +52,19 @@ impl Sampling for Module { }) } - fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + fn add_dist_f64>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ) { + ) where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", @@ -96,16 +93,10 @@ impl Sampling for Module { } } - fn add_normal + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - log_k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) { + fn add_normal(&self, log_base2k: usize, a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: 
f64) + where + A: VecZnxToMut, + { self.add_dist_f64( log_base2k, a, diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index c5052eb..acdac8c 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,13 +1,10 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -// pub const SCALAR_ZNX_ROWS: usize = 1; -// pub const SCALAR_ZNX_SIZE: usize = 1; - pub struct Scalar { data: D, n: usize, @@ -30,7 +27,9 @@ impl ZnxInfos for Scalar { fn size(&self) -> usize { 1 } +} +impl ZnxSliceSize for Scalar { fn sl(&self) -> usize { self.n() } @@ -70,19 +69,6 @@ impl + AsRef<[u8]>> Scalar { .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } - - // pub fn alias_as_vec_znx(&self) -> VecZnx { - // VecZnx { - // inner: ZnxBase { - // n: self.n(), - // rows: 1, - // cols: 1, - // size: 1, - // data: Vec::new(), - // ptr: self.ptr() as *mut u8, - // }, - // } - // } } impl>> Scalar { @@ -116,7 +102,6 @@ pub trait ScalarAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; fn new_scalar(&self, cols: usize) -> ScalarOwned; fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; - // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; } impl ScalarAlloc for Module { @@ -129,31 +114,62 @@ impl ScalarAlloc for Module { fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { ScalarOwned::new_from_bytes::(self.n(), cols, bytes) } - // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { - // Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) - // } } -// impl ZnxAlloc for 
Scalar { -// type Scalar = i64; +pub trait ScalarToRef { + fn to_ref(&self) -> Scalar<&[u8]>; +} -// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { -// Self { -// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), -// } -// } +pub trait ScalarToMut { + fn to_mut(&mut self) -> Scalar<&mut [u8]>; +} -// fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { -// debug_assert_eq!( -// _rows, SCALAR_ZNX_ROWS, -// "rows != {} not supported for Scalar", -// SCALAR_ZNX_ROWS -// ); -// debug_assert_eq!( -// _size, SCALAR_ZNX_SIZE, -// "rows != {} not supported for Scalar", -// SCALAR_ZNX_SIZE -// ); -// module.n() * cols * std::mem::size_of::() -// } -// } +impl ScalarToMut for Scalar> { + fn to_mut(&mut self) -> Scalar<&mut [u8]> { + Scalar { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToMut for Scalar<&mut [u8]> { + fn to_mut(&mut self) -> Scalar<&mut [u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar<&mut [u8]> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar<&[u8]> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 09b26d4..c93609f 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,19 +2,16 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, 
ZnxSliceSize, ZnxView, alloc_aligned}; -pub const SCALAR_ZNX_DFT_ROWS: usize = 1; -pub const SCALAR_ZNX_DFT_SIZE: usize = 1; - -pub struct ScalarZnxDft { +pub struct ScalarZnxDft { data: D, n: usize, cols: usize, _phantom: PhantomData, } -impl ZnxInfos for ScalarZnxDft { +impl ZnxInfos for ScalarZnxDft { fn cols(&self) -> usize { self.cols } @@ -30,20 +27,22 @@ impl ZnxInfos for ScalarZnxDft { fn size(&self) -> usize { 1 } +} +impl ZnxSliceSize for ScalarZnxDft { fn sl(&self) -> usize { self.n() } } -impl DataView for ScalarZnxDft { +impl DataView for ScalarZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for ScalarZnxDft { +impl DataViewMut for ScalarZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -78,20 +77,69 @@ impl>, B: Backend> ScalarZnxDft { _phantom: PhantomData, } } - - // fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - // debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); - // Self { - // inner: ZnxBase::from_bytes_borrow( - // module.n(), - // SCALAR_ZNX_DFT_ROWS, - // cols, - // SCALAR_ZNX_DFT_SIZE, - // bytes, - // ), - // _phantom: PhantomData, - // } - // } } pub type ScalarZnxDftOwned = ScalarZnxDft, B>; + +pub trait ScalarZnxDftToRef { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B>; +} + +pub trait ScalarZnxDftToMut { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>; +} + +impl ScalarZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + ScalarZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { 
+ ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index fc56e4e..ea98a57 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,26 +1,28 @@ -use crate::ffi::svp::{self, svp_ppol_t}; +use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, FFT64, Module, Scalar, ScalarZnxDft, ScalarZnxDftOwned, VecZnx, VecZnxDft}; +use crate::{ + Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx, + VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize, +}; -pub trait ScalarZnxDftAlloc { +pub trait ScalarZnxDftAlloc { fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } -pub trait ScalarZnxDftOps { - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); - fn svp_apply_dft( - &self, - res: &mut VecZnxDft, - res_col: usize, - a: &ScalarZnxDft, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarToRef; + fn svp_apply(&self, res: &mut R, 
res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxToRef; } impl ScalarZnxDftAlloc for Module { @@ -35,42 +37,38 @@ impl ScalarZnxDftAlloc for Module { fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) } - - // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { - // ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) - // } } -impl ScalarZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { +impl ScalarZnxDftOps for Module { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarToRef, + { unsafe { svp::svp_prepare( self.ptr, - res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t, - a.at_ptr(a_col, 0), + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 0), ) } } - fn svp_apply_dft( - &self, - res: &mut VecZnxDft, - res_col: usize, - a: &ScalarZnxDft, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); unsafe { svp::svp_apply_dft( self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, - a.at_ptr(a_col, 0) as *const svp_ppol_t, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 09b0051..70d8fb3 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ 
-1,14 +1,13 @@ use crate::DataView; use crate::DataViewMut; +use crate::ZnxSliceSize; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut, switch_degree}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use std::{cmp::min, fmt}; -// pub const VEC_ZNX_ROWS: usize = 1; - /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array @@ -20,7 +19,7 @@ use std::{cmp::min, fmt}; /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. pub struct VecZnx { - data: D, + pub data: D, n: usize, cols: usize, size: usize, @@ -42,9 +41,11 @@ impl ZnxInfos for VecZnx { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for VecZnx { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } @@ -66,10 +67,6 @@ impl> ZnxView for VecZnx { } impl + AsRef<[u8]>> VecZnx { - pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { - normalize(log_base2k, self, col, carry) - } - /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -92,11 +89,6 @@ impl + AsRef<[u8]>> VecZnx { .for_each(|x: &mut i64| *x &= mask) } } - - /// Switches degree of from `a.n()` to `self.n()` into `self` - pub fn switch_degree>(&mut self, col: usize, a: &VecZnx, col_a: usize) { - switch_degree(self, col_a, a, col) - } } impl>> VecZnx { @@ -126,6 +118,17 @@ impl>> VecZnx { } } +impl VecZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + } + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. 
@@ -141,10 +144,12 @@ where data_b[..size].copy_from_slice(&data_a[..size]) } +#[allow(dead_code)] fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } +#[allow(dead_code)] fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); @@ -216,8 +221,16 @@ pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; -impl VecZnx> { - pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> { +pub trait VecZnxToRef { + fn to_ref(&self) -> VecZnx<&[u8]>; +} + +pub trait VecZnxToMut { + fn to_mut(&mut self) -> VecZnx<&mut [u8]>; +} + +impl VecZnxToMut for VecZnx> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, @@ -225,8 +238,10 @@ impl VecZnx> { size: self.size, } } +} - pub fn to_ref(&self) -> VecZnx<&[u8]> { +impl VecZnxToRef for VecZnx> { + fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, @@ -236,10 +251,32 @@ impl VecZnx> { } } -impl VecZnx<&mut [u8]> { - pub fn to_ref(&self) -> VecZnx<&[u8]> { +impl VecZnxToMut for VecZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { - data: &self.data, + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, n: self.n, cols: self.cols, size: self.size, diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index fe67516..8f70272 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,12 +1,9 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; +use crate::{Backend, 
DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::marker::PhantomData; -// const VEC_ZNX_BIG_ROWS: usize = 1; - -/// VecZnxBig is `Backend` dependent, denoted with backend generic `B` -pub struct VecZnxBig { +pub struct VecZnxBig { data: D, n: usize, cols: usize, @@ -14,7 +11,7 @@ pub struct VecZnxBig { _phantom: PhantomData, } -impl ZnxInfos for VecZnxBig { +impl ZnxInfos for VecZnxBig { fn cols(&self) -> usize { self.cols } @@ -30,20 +27,22 @@ impl ZnxInfos for VecZnxBig { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } -impl DataView for VecZnxBig { +impl DataView for VecZnxBig { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for VecZnxBig { +impl DataViewMut for VecZnxBig { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -82,7 +81,7 @@ impl>, B: Backend> VecZnxBig { } } -impl VecZnxBig { +impl VecZnxBig { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, @@ -96,8 +95,16 @@ impl VecZnxBig { pub type VecZnxBigOwned = VecZnxBig, B>; -impl VecZnxBig, B> { - pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { +pub trait VecZnxBigToRef { + fn to_ref(&self) -> VecZnxBig<&[u8], B>; +} + +pub trait VecZnxBigToMut { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; +} + +impl VecZnxBigToMut for VecZnxBig, B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut_slice(), n: self.n, @@ -106,8 +113,10 @@ impl VecZnxBig, B> { _phantom: PhantomData, } } +} - pub fn to_ref(&self) -> VecZnxBig<&[u8], B> { +impl VecZnxBigToRef for VecZnxBig, B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_slice(), n: self.n, @@ -117,3 +126,39 @@ impl VecZnxBig, B> { } } } + +impl VecZnxBigToMut for VecZnxBig<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + VecZnxBig { + data: self.data, + n: 
self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&mut [u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&[u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index d0e4bd3..185a20c 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,8 +1,11 @@ use crate::ffi::vec_znx; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big}; +use crate::{ + Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch, + VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big, +}; -pub trait VecZnxBigAlloc { +pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned; @@ -39,79 +42,77 @@ pub trait VecZnxBigAlloc { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxBigOps { +pub trait VecZnxBigOps { /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef; /// Adds `a` to `b` and stores the result on `b`. 
- fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add_small( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef; /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Subtracts `a` to `b` and stores the result on `c`. - fn vec_znx_big_sub( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef; /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Subtracts `b` from `a` and stores the result on `c`. 
- fn vec_znx_big_sub_small_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef; /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Subtracts `b` from `a` and stores the result on `c`. - fn vec_znx_big_sub_small_b( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef; /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Normalizes `a` and stores the result on `b`. /// @@ -119,28 +120,28 @@ pub trait VecZnxBigOps { /// /// * `log_base2k`: normalization basis. /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut R, res_col: usize, - a: &VecZnxBig, + a: &A, a_col: usize, scratch: &mut Scratch, - ); + ) where + R: VecZnxToMut, + A: VecZnxBigToRef; /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
- fn vec_znx_big_automorphism( - &self, - k: i64, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; } pub trait VecZnxBigScratch { @@ -157,29 +158,22 @@ impl VecZnxBigAlloc for Module { VecZnxBig::new_from_bytes(self, cols, size, bytes) } - // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - // VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) - // } - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { bytes_of_vec_znx_big(self, cols, size) } } -impl VecZnxBigOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn vec_znx_big_add( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { +impl VecZnxBigOps for Module { + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -203,13 +197,14 @@ where } } - fn vec_znx_big_add_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: 
VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -231,15 +226,16 @@ where } } - fn vec_znx_big_sub( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -263,13 +259,14 @@ where } } - fn vec_znx_big_sub_ab_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -291,13 +288,14 @@ where } } - fn vec_znx_big_sub_ba_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -319,15 +317,16 @@ where } } - fn vec_znx_big_sub_small_b( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let 
b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -351,13 +350,14 @@ where } } - fn vec_znx_big_sub_small_b_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - ) { + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -379,15 +379,16 @@ where } } - fn vec_znx_big_sub_small_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -411,13 +412,14 @@ where } } - fn vec_znx_big_sub_small_a_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - ) { + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -439,15 +441,16 @@ where } } - fn vec_znx_big_add_small( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: 
VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -471,7 +474,14 @@ where } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -493,22 +503,28 @@ where } } - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut R, res_col: usize, - a: &VecZnxBig, + a: &A, a_col: usize, scratch: &mut Scratch, - ) { + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. 
// In the FFT backend the tmp sizes are same but will be different in the NTT backend - // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); // assert_alignement(tmp_bytes.as_ptr()); } @@ -530,14 +546,14 @@ where } } - fn vec_znx_big_automorphism( - &self, - k: i64, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -557,7 +573,12 @@ where } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index a4a3242..66e58cf 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,12 +2,9 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -// const VEC_ZNX_DFT_ROWS: usize = 1; - -// VecZnxDft is `Backend` dependent denoted with generic `B` -pub struct VecZnxDft { +pub struct VecZnxDft { data: D, n: usize, cols: usize, @@ -15,7 +12,7 @@ pub struct VecZnxDft { _phantom: PhantomData, } -impl ZnxInfos for VecZnxDft { +impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols } @@ -31,20 +28,22 @@ impl ZnxInfos for VecZnxDft { fn size(&self) -> usize { self.size } +} +impl 
ZnxSliceSize for VecZnxDft { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } -impl DataView for VecZnxDft { +impl DataView for VecZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for VecZnxDft { +impl DataViewMut for VecZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -85,7 +84,7 @@ impl>, B: Backend> VecZnxDft { pub type VecZnxDftOwned = VecZnxDft, B>; -impl VecZnxDft { +impl VecZnxDft { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, @@ -97,8 +96,16 @@ impl VecZnxDft { } } -impl VecZnxDft, B> { - pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { +pub trait VecZnxDftToRef { + fn to_ref(&self) -> VecZnxDft<&[u8], B>; +} + +pub trait VecZnxDftToMut { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; +} + +impl VecZnxDftToMut for VecZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { data: self.data.as_mut_slice(), n: self.n, @@ -107,8 +114,10 @@ impl VecZnxDft, B> { _phantom: PhantomData, } } +} - pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { +impl VecZnxDftToRef for VecZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { VecZnxDft { data: self.data.as_slice(), n: self.n, @@ -119,10 +128,34 @@ impl VecZnxDft, B> { } } -impl VecZnxDft<&mut [u8], B> { - pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { +impl VecZnxDftToMut for VecZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { - data: &self.data, + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, n: self.n, cols: self.cols, size: 
self.size, diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index e894ef4..83b7c26 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,11 +1,11 @@ use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, VecZnxDftOwned}; -use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; +use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize}; +use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; use std::cmp::min; -pub trait VecZnxDftAlloc { +pub trait VecZnxDftAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; @@ -34,24 +34,26 @@ pub trait VecZnxDftAlloc { fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxDftOps { +pub trait VecZnxDftOps { /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. 
- fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut; - fn vec_znx_idft( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxDft, - a_col: usize, - tmp_bytes: &mut [u8], - ); + fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef; - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef; } impl VecZnxDftAlloc for Module { @@ -63,41 +65,34 @@ impl VecZnxDftAlloc for Module { VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) } - // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { - // VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) - // } - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { bytes_of_vec_znx_dft(self, cols, size) } } -impl VecZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn vec_znx_idft_tmp_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &mut VecZnxDft, - a_col: usize, - ) { - let min_size: usize = min(res.size(), a.size()); +impl VecZnxDftOps for Module { + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + let mut res_mut = res.to_mut(); + let mut a_mut = a.to_mut(); + + let min_size: usize = min(res_mut.size(), a_mut.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft_tmp_a( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_mut_ptr(a_col, j) as *mut 
vec_znx_dft::vec_znx_dft_t, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }) } } @@ -110,61 +105,59 @@ where /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { - let min_size: usize = min(res.size(), a.size()); + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + let mut res_mut = res.to_mut(); + let a_ref = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_dft( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, - a.at_ptr(a_col, j), + a_ref.at_ptr(a_col, j), 1 as u64, - a.sl() as u64, + a_ref.sl() as u64, ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }); } } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. 
- fn vec_znx_idft( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxDft, - a_col: usize, - tmp_bytes: &mut [u8], - ) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= >::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - >::vec_znx_idft_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } + fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef, + { + let mut res_mut = res.to_mut(); + let a_ref = a.to_ref(); - let min_size: usize = min(res.size(), a.size()); + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes()); + + let min_size: usize = min(res_mut.size(), a_ref.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1 as u64, tmp_bytes.as_mut_ptr(), ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }); } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index a8edb12..cdabe24 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,9 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxInfos, switch_degree}; -use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement}; +use crate::{ + Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, +}; +use itertools::izip; +use std::cmp::min; pub trait VecZnxAlloc { /// Allocates a new [VecZnx]. 
@@ -29,73 +32,86 @@ pub trait VecZnxAlloc { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxOps { +pub trait VecZnxOps { /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. - fn vec_znx_normalize( - &self, - log_base2k: usize, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - tmp_bytes: &mut [u8], - ); + fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut; /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. - fn vec_znx_add( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef; /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. 
- fn vec_znx_sub( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef; /// Subtracts the selected column of `a` from the selected column of `res` inplace. /// /// res[res_col] -= a[a_col] - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res` /// /// res[res_col] = a[a_col] - res[res_col] - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; // Negates the selected column of `a` and stores the result in `res_col` of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Multiplies the selected column of `a` by X^k. 
- fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -103,14 +119,10 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split( - &self, - res: &mut Vec>, - res_col: usize, - a: &VecZnx, - a_col: usize, - buf: &mut VecZnx, - ); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Merges the subrings of the selected column of `a` into the selected column of `res`. 
/// @@ -118,7 +130,15 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize); + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; + + fn switch_degree(&self, r: &mut R, col_b: usize, a: &A, col_a: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; } pub trait VecZnxScratch { @@ -140,27 +160,23 @@ impl VecZnxAlloc for Module { } } -impl VecZnxOps for Module -where - Data: AsRef<[u8]>, - DataMut: AsRef<[u8]> + AsMut<[u8]>, -{ - fn vec_znx_normalize( - &self, - log_base2k: usize, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - tmp_bytes: &mut [u8], - ) { +impl VecZnxOps for Module { + fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= ::vec_znx_normalize_tmp_bytes(&self)); - assert_alignement(tmp_bytes.as_ptr()); } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, @@ -176,22 +192,44 @@ where } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + unsafe { - let a_ptr: *const VecZnx<_> = a; - 
Self::vec_znx_normalize(self, log_base2k, a, a_col, &*a_ptr, a_col, tmp_bytes); + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); } } - fn vec_znx_add( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -215,7 +253,14 @@ where } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -237,15 +282,16 @@ where } } - fn vec_znx_sub( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -269,7 +315,13 @@ where } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: 
VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -291,7 +343,13 @@ where } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -313,7 +371,13 @@ where } } - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -332,14 +396,35 @@ where } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } unsafe { - let a_ref: *const VecZnx<_> = a; - Self::vec_znx_negate(self, a, a_col, a_ref.as_ref().unwrap(), a_col); + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -359,7 +444,11 @@ where } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, 
a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -378,7 +467,13 @@ where } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -398,7 +493,11 @@ where } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -417,23 +516,24 @@ where } } - fn vec_znx_split( - &self, - res: &mut Vec>, - res_col: usize, - a: &VecZnx, - a_col: usize, - buf: &mut VecZnx, - ) { - let (n_in, n_out) = (a.n(), res[0].n()); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + + let (n_in, n_out) = (a.n(), res[0].to_mut().n()); + + let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size()); debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); - res[1..].iter().for_each(|bi| { + res[1..].iter_mut().for_each(|bi| { debug_assert_eq!( - bi.n(), + bi.to_mut().n(), n_out, "invalid input a: all VecZnx must have the same degree" ) @@ -441,17 +541,23 @@ where res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, res_col, a, a_col); - self.vec_znx_rotate(-1, buf, 0, a, a_col); + self.switch_degree(bi, res_col, &a, a_col); + self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); } else { - switch_degree(bi, res_col, buf, a_col); - >::vec_znx_rotate_inplace(self, 
-1, buf, a_col); + self.switch_degree(bi, res_col, &mut buf, a_col); + self.vec_znx_rotate_inplace(-1, &mut buf, a_col); } }) } - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize) { - let (n_in, n_out) = (res.n(), a[0].n()); + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (res.n(), a[0].to_ref().n()); debug_assert!( n_out < n_in, @@ -459,18 +565,47 @@ where ); a[1..].iter().for_each(|ai| { debug_assert_eq!( - ai.n(), + ai.to_ref().n(), n_out, "invalid input a: all VecZnx must have the same degree" ) }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(res, res_col, ai, a_col); - >::vec_znx_rotate_inplace(self, -1, res, res_col); + self.switch_degree(&mut res, res_col, ai, a_col); + self.vec_znx_rotate_inplace(-1, &mut res, res_col); }); - >::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col); + self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); + } + + fn switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (a.n(), res.n()); + let (gap_in, gap_out): (usize, usize); + + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + res.zero(); + } + + let size: usize = min(a.size(), res.size()); + + (0..size).for_each(|i| { + izip!( + a.at(a_col, i).iter().step_by(gap_in), + res.at_mut(res_col, i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); } } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 9eea5bb..db6a50c 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -1,6 +1,5 @@ use itertools::izip; use rand_distr::num_traits::Zero; -use std::cmp::min; pub trait ZnxInfos 
{ /// Returns the ring degree of the polynomials. @@ -24,7 +23,9 @@ pub trait ZnxInfos { fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } +} +pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. fn sl(&self) -> usize; @@ -129,33 +130,6 @@ where impl ZnxZero for T where T: ZnxViewMut {} // impl ZnxRsh for T where T: ZnxZero {} -pub fn switch_degree + ZnxZero, D: ZnxView>( - b: &mut DMut, - col_b: usize, - a: &D, - col_a: usize, -) { - let (n_in, n_out) = (a.n(), b.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - b.zero(); - } - - let size: usize = min(a.size(), b.size()); - - (0..size).for_each(|i| { - izip!( - a.at(col_a, i).iter().step_by(gap_in), - b.at_mut(col_b, i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); -} - use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; use crate::Scratch; From 08e81f50c9bfa72984474a9e83dc85372b0c1f42 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 11:30:55 +0200 Subject: [PATCH 29/87] updated sampling traits --- base2k/examples/rlwe_encrypt.rs | 11 +- base2k/spqlios-arithmetic | 2 +- base2k/src/mat_znx_dft.rs | 2 +- base2k/src/mat_znx_dft_ops.rs | 22 +- base2k/src/sampling.rs | 270 ++++++++++++++++++---- rlwe/benches/gadget_product.rs | 139 ------------ rlwe/examples/encryption.rs | 76 ------- rlwe/src/automorphism.rs | 349 ----------------------------- rlwe/src/ciphertext.rs | 95 +------- rlwe/src/decryptor.rs | 67 ------ rlwe/src/elem.rs | 168 -------------- rlwe/src/encryptor.rs | 369 ------------------------------ rlwe/src/gadget_product.rs | 383 -------------------------------- rlwe/src/key_generator.rs | 55 ----- rlwe/src/key_switching.rs | 79 ------- rlwe/src/keys.rs | 82 ------- rlwe/src/lib.rs | 14 +- rlwe/src/parameters.rs | 88 -------- 
rlwe/src/plaintext.rs | 109 --------- rlwe/src/rgsw_product.rs | 300 ------------------------- rlwe/src/test.rs | 113 ---------- rlwe/src/trace.rs | 236 -------------------- 22 files changed, 251 insertions(+), 2778 deletions(-) delete mode 100644 rlwe/benches/gadget_product.rs delete mode 100644 rlwe/examples/encryption.rs delete mode 100644 rlwe/src/automorphism.rs delete mode 100644 rlwe/src/decryptor.rs delete mode 100644 rlwe/src/elem.rs delete mode 100644 rlwe/src/encryptor.rs delete mode 100644 rlwe/src/gadget_product.rs delete mode 100644 rlwe/src/key_generator.rs delete mode 100644 rlwe/src/key_switching.rs delete mode 100644 rlwe/src/keys.rs delete mode 100644 rlwe/src/parameters.rs delete mode 100644 rlwe/src/plaintext.rs delete mode 100644 rlwe/src/rgsw_product.rs delete mode 100644 rlwe/src/test.rs delete mode 100644 rlwe/src/trace.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index b55efba..79270ea 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,7 +1,7 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, - VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, - VecZnxOps, ZnxInfos, + AddNormal, Encoding, FFT64, FillUniform, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, + VecZnxDftOps, VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -36,7 +36,7 @@ fn main() { ); // Fill the second column with random values: ct = (0, a) - module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); + ct.fill_uniform(log_base2k, 1, ct_size, &mut source); let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); @@ -88,9 +88,8 @@ fn main() { 
// Add noise to ct[0] // ct[0] <- ct[0] + e - module.add_normal( + ct.add_normal( log_base2k, - &mut ct, 0, // Selects the first column of ct (ct[0]) log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * limbs} &mut source, diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index e3d3247..8135d85 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit e3d3247335faccf2b6361213c354cd61b958325e +Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 1f18b48..c34115d 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -151,7 +151,7 @@ impl> MatZnxDft { } } -pub type MatZnxDftAllocOwned = MatZnxDft, B>; +pub type MatZnxDftOwned = MatZnxDft, B>; pub trait MatZnxDftToRef { fn to_ref(&self) -> MatZnxDft<&[u8], B>; diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 9b79a2c..ae0cbb5 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, + Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, }; @@ -13,7 +13,7 @@ pub trait MatZnxDftAlloc { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). 
- fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned; + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned; fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; @@ -24,7 +24,7 @@ pub trait MatZnxDftAlloc { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDftAllocOwned; + ) -> MatZnxDftOwned; } pub trait MatZnxDftScratch { @@ -103,11 +103,11 @@ pub trait MatZnxDftOps { impl MatZnxDftAlloc for Module { fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - MatZnxDftAllocOwned::bytes_of(self, rows, cols_in, cols_out, size) + MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size) } - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned { - MatZnxDftAllocOwned::new(self, rows, cols_in, cols_out, size) + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned { + MatZnxDftOwned::new(self, rows, cols_in, cols_out, size) } fn new_mat_znx_dft_from_bytes( @@ -117,8 +117,8 @@ impl MatZnxDftAlloc for Module { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDftAllocOwned { - MatZnxDftAllocOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) + ) -> MatZnxDftOwned { + MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) } } @@ -305,8 +305,8 @@ impl MatZnxDftOps for Module { #[cfg(test)] mod tests { use crate::{ - Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, - VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, + Encoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, 
VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, }; use sampling::source::Source; @@ -329,7 +329,7 @@ mod tests { for row_i in 0..mat_rows { let mut source: Source = Source::new([0u8; 32]); (0..mat_cols_out).for_each(|col_out| { - module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source); + a.fill_uniform(log_base2k, col_out, mat_size, &mut source); module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b254286..212658a 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,47 +1,53 @@ use crate::znx_base::ZnxViewMut; -use crate::{Backend, Module, VecZnx, VecZnxToMut}; +use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; -pub trait Sampling { +pub trait FillUniform { /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) - where - A: VecZnxToMut; + fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source); +} - /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( - &self, +pub trait FillDistF64 { + fn fill_dist_f64>( + &mut self, log_base2k: usize, - a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ) where - A: VecZnxToMut; + ); +} - /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal( - &self, +pub trait AddDistF64 { + /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. 
+ fn add_dist_f64>( + &mut self, log_base2k: usize, - a: &mut A, col_i: usize, log_k: usize, source: &mut Source, - sigma: f64, + dist: D, bound: f64, - ) where - A: VecZnxToMut; + ); } -impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) - where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); +pub trait FillNormal { + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64); +} + +pub trait AddNormal { + /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64); +} + +impl FillUniform for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; @@ -51,20 +57,65 @@ impl Sampling for Module { .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); }) } +} - fn add_dist_f64>( - &self, +impl FillDistF64 for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn fill_dist_f64>( + &mut self, log_base2k: usize, - a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ) where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); + ) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = log_k % log_base2k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + 
while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +impl AddDistF64 for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn add_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", @@ -92,14 +143,149 @@ impl Sampling for Module { }); } } +} - fn add_normal(&self, log_base2k: usize, a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) - where - A: VecZnxToMut, - { +impl FillNormal for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.fill_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +impl AddNormal for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.add_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +impl FillDistF64 for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn fill_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let 
log_base2k_rem: usize = log_k % log_base2k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +impl AddDistF64 for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn add_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = log_k % log_base2k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); + } + } +} + +impl FillNormal for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.fill_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +impl AddNormal for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut 
Source, sigma: f64, bound: f64) { self.add_dist_f64( log_base2k, - a, col_i, log_k, source, @@ -111,14 +297,16 @@ impl Sampling for Module { #[cfg(test)] mod tests { - use super::Sampling; + use std::fmt::Display; + + use super::{AddNormal, FillUniform}; use crate::vec_znx_ops::*; use crate::znx_base::*; use crate::{FFT64, Module, Stats, VecZnx}; use sampling::source::Source; #[test] - fn fill_uniform() { + fn vec_znx_fill_uniform() { let n: usize = 4096; let module: Module = Module::::new(n); let log_base2k: usize = 17; @@ -129,7 +317,7 @@ mod tests { let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { let mut a: VecZnx<_> = module.new_vec_znx(cols, size); - module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); + a.fill_uniform(log_base2k, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { @@ -149,7 +337,7 @@ mod tests { } #[test] - fn add_normal() { + fn vec_znx_add_normal() { let n: usize = 4096; let module: Module = Module::::new(n); let log_base2k: usize = 17; @@ -163,7 +351,7 @@ mod tests { let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { let mut a: VecZnx<_> = module.new_vec_znx(cols, size); - module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); + a.add_normal(log_base2k, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs deleted file mode 100644 index 14bb06d..0000000 --- a/rlwe/benches/gadget_product.rs +++ /dev/null @@ -1,139 +0,0 @@ -use base2k::{BACKEND, Module, Sampling, ScalarZnxDftOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, alloc_aligned_u8}; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use rlwe::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::ElemCommon, - encryptor::{encrypt_grlwe_sk, 
encrypt_grlwe_sk_tmp_bytes}, - gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes}, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, -}; -use sampling::source::Source; - -fn bench_gadget_product_inplace(c: &mut Criterion) { - fn runner<'a>( - module: &'a Module, - res_dft_0: &'a mut VecZnxDft, - res_dft_1: &'a mut VecZnxDft, - a: &'a VecZnx, - b: &'a Ciphertext, - b_cols: usize, - tmp_bytes: &'a mut [u8], - ) -> Box { - Box::new(move || { - gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes); - }) - } - - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace"); - - for log_n in 10..11 { - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: log_n, - log_q: 32, - log_p: 0, - log_base2k: 16, - log_scale: 20, - xe: 3.2, - xs: 128, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let mut tmp_bytes: Vec = alloc_aligned_u8( - params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) - | gadget_product_core_tmp_bytes( - params.module(), - params.log_base2k(), - params.log_q(), - params.log_q(), - params.cols_q(), - params.log_qp(), - ) - | encrypt_grlwe_sk_tmp_bytes( - params.module(), - params.log_base2k(), - params.cols_qp(), - params.log_qp(), - ), - ); - - let mut source: Source = Source::new([3; 32]); - - let mut sk0: SecretKey = SecretKey::new(params.module()); - let mut sk1: SecretKey = SecretKey::new(params.module()); - sk0.fill_ternary_hw(params.xs(), &mut source); - sk1.fill_ternary_hw(params.xs(), &mut source); - - let mut source_xe: Source = Source::new([4; 32]); - let mut source_xa: Source = Source::new([5; 32]); - - let mut sk0_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - - let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - - let 
mut gadget_ct: Ciphertext = new_gadget_ciphertext( - params.module(), - params.log_base2k(), - params.cols_q(), - params.log_qp(), - ); - - encrypt_grlwe_sk( - params.module(), - &mut gadget_ct, - &sk0.0, - &sk1_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - params.encrypt_rlwe_sk( - &mut ct, - None, - &sk0_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols()); - let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols()); - - let mut a: VecZnx = params.module().new_vec_znx(0, params.cols_q()); - params - .module() - .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); - - let b_cols: usize = gadget_ct.cols(); - - let runners: [(String, Box); 1] = [(format!("gadget_product"), { - runner( - params.module(), - &mut res_dft_0, - &mut res_dft_1, - &mut a, - &gadget_ct, - b_cols, - &mut tmp_bytes, - ) - })]; - - for (name, mut runner) in runners { - let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n)); - b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| { - b.iter(&mut runner) - }); - } - } -} - -criterion_group!(benches, bench_gadget_product_inplace); -criterion_main!(benches); diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs deleted file mode 100644 index 20a0603..0000000 --- a/rlwe/examples/encryption.rs +++ /dev/null @@ -1,76 +0,0 @@ -use base2k::{Encoding, ScalarZnxDftOps, VecZnx, alloc_aligned}; -use rlwe::{ - ciphertext::Ciphertext, - elem::ElemCommon, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, -}; -use sampling::source::Source; - -fn main() { - let params_lit: ParametersLiteral = ParametersLiteral { - backend: base2k::BACKEND::FFT64, - log_n: 10, - log_q: 54, - log_p: 0, - log_base2k: 17, - log_scale: 20, 
- xe: 3.2, - xs: 128, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let mut tmp_bytes: Vec = - alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q())); - - let mut source: Source = Source::new([0; 32]); - let mut sk: SecretKey = SecretKey::new(params.module()); - sk.fill_ternary_hw(params.xs(), &mut source); - - let mut want = vec![i64::default(); params.n()]; - - want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - - let mut pt: Plaintext = params.new_plaintext(params.log_q()); - - let log_base2k = pt.log_base2k(); - - let log_k: usize = params.log_q() - 20; - - pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32); - pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); - - println!("log_k: {}", log_k); - pt.0.value[0].print(0, pt.cols(), 16); - println!(); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - let mut source_xe: Source = Source::new([1; 32]); - let mut source_xa: Source = Source::new([2; 32]); - - let mut sk_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); - - params.encrypt_rlwe_sk( - &mut ct, - Some(&pt), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); - pt.0.value[0].print(0, pt.cols(), 16); - - let mut have = vec![i64::default(); params.n()]; - - println!("pt: {}", log_k); - pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have); - - println!("want: {:?}", &want[..16]); - println!("have: {:?}", &have[..16]); -} diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs deleted file mode 100644 index ea2b834..0000000 --- a/rlwe/src/automorphism.rs +++ /dev/null @@ -1,349 +0,0 @@ -use crate::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::ElemCommon, - encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, - 
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes}, - keys::SecretKey, - parameters::Parameters, -}; -use base2k::{ - MatZnxDft, MatZnxDftOps, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, - VecZnxDft, VecZnxDftOps, VecZnxOps, assert_alignement, -}; -use sampling::source::Source; -use std::collections::HashMap; - -/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} -pub struct AutomorphismKey { - pub value: Ciphertext, - pub p: i64, -} - -pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) -} - -impl Parameters { - pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { - automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) - } - - pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - automorphism_tmp_bytes( - self.module(), - self.log_base2k(), - res_logq, - in_logq, - gct_logq, - ) - } -} - -impl AutomorphismKey { - pub fn new( - module: &Module, - p: i64, - sk: &SecretKey, - log_base2k: usize, - rows: usize, - log_q: usize, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], - ) -> Self { - Self::new_many_core( - module, - &vec![p], - sk, - log_base2k, - rows, - log_q, - source_xa, - source_xe, - sigma, - tmp_bytes, - ) - .into_iter() - .next() - .unwrap() - } - - pub fn new_many( - module: &Module, - p: &Vec, - sk: &SecretKey, - log_base2k: usize, - rows: usize, - log_q: usize, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], - ) -> HashMap { - Self::new_many_core( - module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes, - ) - .into_iter() - .zip(p.iter().cloned()) - 
.map(|(key, pi)| (pi, key)) - .collect() - } - - fn new_many_core( - module: &Module, - p: &Vec, - sk: &SecretKey, - log_base2k: usize, - rows: usize, - log_q: usize, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], - ) -> Vec { - let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); - let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft()); - - let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); - let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes); - - let mut keys: Vec = Vec::new(); - - p.iter().for_each(|pi| { - let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); - - let p_inv: i64 = module.galois_element_inv(*pi); - - module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); - module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto); - encrypt_grlwe_sk( - module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, - ); - - keys.push(Self { - value: value, - p: *pi, - }) - }); - - keys - } -} - -pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq) -} - -pub fn automorphism( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &AutomorphismKey, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes); - // c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)] - module.vec_znx_automorphism_inplace(b.p, c.at_mut(0)); - // c[1] = AUTO(b, p) - module.vec_znx_automorphism_inplace(b.p, c.at_mut(1)); -} - -pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize { - return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) - 
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols)); -} - -pub fn automorphism_inplace( - module: &Module, - a: &mut Ciphertext, - b: &AutomorphismKey, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes); - // a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)] - module.vec_znx_automorphism_inplace(b.p, a.at_mut(0)); - // a[1] = AUTO(b, p) - module.vec_znx_automorphism_inplace(b.p, a.at_mut(1)); -} - -pub fn automorphism_big( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &AutomorphismKey, - tmp_bytes: &mut [u8], -) { - let cols = std::cmp::min(c.cols(), a.cols()); - - #[cfg(debug_assertions)] - { - assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols())); - assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft); - - // a1_dft = DFT(a[1]) - module.vec_znx_dft(&mut a1_dft, a.at(1)); - - // res_dft = IDFT() = [-b*AUTO(s, -p) + a * s + e] - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes); - module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft); - - // res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e] - module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0)); - - // c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)] - module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0)); - - // res_dft = IDFT() = [b] - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes); - 
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft); - - // c[1] = AUTO(b, p) - module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1)); -} - -#[cfg(test)] -mod test { - use super::{AutomorphismKey, automorphism}; - use crate::{ - ciphertext::Ciphertext, - decryptor::decrypt_rlwe, - elem::ElemCommon, - encryptor::encrypt_rlwe_sk, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, - }; - use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, alloc_aligned}; - use sampling::source::{Source, new_seed}; - - #[test] - fn test_automorphism() { - let log_base2k: usize = 10; - let log_q: usize = 50; - let log_p: usize = 15; - - // Basic parameters with enough limbs to test edge cases - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: 12, - log_q: log_q, - log_p: log_p, - log_base2k: log_base2k, - log_scale: 20, - xe: 3.2, - xs: 1 << 11, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let module: &Module = params.module(); - let log_q: usize = params.log_q(); - let log_qp: usize = params.log_qp(); - let gct_rows: usize = params.cols_q(); - let gct_cols: usize = params.cols_qp(); - - // scratch space - let mut tmp_bytes: Vec = alloc_aligned( - params.decrypt_rlwe_tmp_byte(log_q) - | params.encrypt_rlwe_sk_tmp_bytes(log_q) - | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp) - | params.automorphism_tmp_bytes(log_q, log_q, log_qp), - ); - - // Samplers for public and private randomness - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - let mut source_xs: Source = Source::new(new_seed()); - - let mut sk: SecretKey = SecretKey::new(module); - sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); - module.svp_prepare(&mut sk_svp_ppol, &sk.0); - - let p: i64 = -5; - - let auto_key: AutomorphismKey = AutomorphismKey::new( - 
module, - p, - &sk, - log_base2k, - gct_rows, - log_qp, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - let mut data: Vec = vec![0i64; params.n()]; - - data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - - let log_k: usize = 2 * log_base2k; - - let mut ct: Ciphertext = params.new_ciphertext(log_q); - let mut pt: Plaintext = params.new_plaintext(log_q); - let mut pt_auto: Plaintext = params.new_plaintext(log_q); - - pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); - module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0)); - - encrypt_rlwe_sk( - module, - &mut ct.elem_mut(), - Some(pt.at(0)), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - let mut ct_auto: Ciphertext = params.new_ciphertext(log_q); - - // ct <- AUTO(ct) - automorphism( - module, - &mut ct_auto, - &ct, - &auto_key, - gct_cols, - &mut tmp_bytes, - ); - - // pt = dec(auto(ct)) - auto(pt) - decrypt_rlwe( - module, - pt.elem_mut(), - ct_auto.elem(), - &sk_svp_ppol, - &mut tmp_bytes, - ); - - module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0)); - - // pt.at(0).print(pt.cols(), 16); - - let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); - - let var_msg: f64 = (params.xs() as f64) / params.n() as f64; - let var_a_err: f64 = 1f64 / 12f64; - - let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q()); - - println!("noise_pred: {}", noise_pred); - println!("noise_have: {}", noise_have); - - assert!(noise_have <= noise_pred + 1.0); - } -} diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index bcffeec..dc83a66 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,93 +1,4 @@ -use crate::elem::{Elem, ElemCommon}; -use crate::parameters::Parameters; -use base2k::{ZnxInfos, Layout, Module, VecZnx, MatZnxDft}; -pub struct Ciphertext(pub Elem); - -impl Parameters { - pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext 
{ - Ciphertext::new(self.module(), self.log_base2k(), log_q, 2) - } -} - -impl ElemCommon for Ciphertext -where - T: ZnxInfos, -{ - fn n(&self) -> usize { - self.elem().n() - } - - fn log_n(&self) -> usize { - self.elem().log_n() - } - - fn log_q(&self) -> usize { - self.elem().log_q() - } - - fn elem(&self) -> &Elem { - &self.0 - } - - fn elem_mut(&mut self) -> &mut Elem { - &mut self.0 - } - - fn size(&self) -> usize { - self.elem().size() - } - - fn layout(&self) -> Layout { - self.elem().layout() - } - - fn rows(&self) -> usize { - self.elem().rows() - } - - fn cols(&self) -> usize { - self.elem().cols() - } - - fn at(&self, i: usize) -> &T { - self.elem().at(i) - } - - fn at_mut(&mut self, i: usize) -> &mut T { - self.elem_mut().at_mut(i) - } - - fn log_base2k(&self) -> usize { - self.elem().log_base2k() - } - - fn log_scale(&self) -> usize { - self.elem().log_scale() - } -} - -impl Ciphertext { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self { - Self(Elem::::new(module, log_base2k, log_q, rows)) - } -} - -pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext { - let rows: usize = 2; - Ciphertext::::new(module, log_base2k, log_q, rows) -} - -pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { - let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); - elem.log_q = log_q; - Ciphertext(elem) -} - -pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { - let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); - elem.log_q = log_q; - Ciphertext(elem) -} +pub struct Ciphertext{ + x +} \ No newline at end of file diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs deleted file mode 100644 index 4c1fb7e..0000000 --- a/rlwe/src/decryptor.rs 
+++ /dev/null @@ -1,67 +0,0 @@ -use crate::{ - ciphertext::Ciphertext, - elem::{Elem, ElemCommon}, - keys::SecretKey, - parameters::Parameters, - plaintext::Plaintext, -}; -use base2k::{Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; -use std::cmp::min; - -pub struct Decryptor { - sk: ScalarZnxDft, -} - -impl Decryptor { - pub fn new(params: &Parameters, sk: &SecretKey) -> Self { - let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); - sk.prepare(params.module(), &mut sk_svp_ppol); - Self { sk: sk_svp_ppol } - } -} - -pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize { - module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes() -} - -impl Parameters { - pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize { - decrypt_rlwe_tmp_byte( - self.module(), - (log_q + self.log_base2k() - 1) / self.log_base2k(), - ) - } - - pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { - decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) - } -} - -pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { - let cols: usize = a.cols(); - - assert!( - tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols), - "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}", - tmp_bytes.len(), - decrypt_rlwe_tmp_byte(module, cols) - ); - - let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - - let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft); - let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); - - // res_dft <- DFT(ct[1]) * DFT(sk) - module.svp_apply_dft(&mut res_dft, sk, a.at(1)); - // res_big <- ct[1] x sk - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - // res_big <- ct[1] x sk + ct[0] - module.vec_znx_big_add_small_inplace(&mut res_big, 
a.at(0)); - // res <- normalize(ct[1] x sk + ct[0]) - module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize); - - res.log_base2k = a.log_base2k(); - res.log_q = min(res.log_q(), a.log_q()); - res.log_scale = a.log_scale(); -} diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs deleted file mode 100644 index c6fe59f..0000000 --- a/rlwe/src/elem.rs +++ /dev/null @@ -1,168 +0,0 @@ -use base2k::{ZnxInfos, Layout, Module, VecZnx, VecZnxOps, MatZnxDft, MatZnxDftOps}; - -pub struct Elem { - pub value: Vec, - pub log_base2k: usize, - pub log_q: usize, - pub log_scale: usize, -} - -pub trait ElemVecZnx { - fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem; - fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem; - fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize; - fn zero(&mut self); -} - -impl ElemVecZnx for Elem { - fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize { - let cols = (log_q + log_base2k - 1) / log_base2k; - module.n() * cols * size * 8 - } - - fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem { - assert!(size > 0); - let n: usize = module.n(); - assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); - let mut value: Vec = Vec::new(); - let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let elem_size = VecZnx::bytes_of(n, size, cols); - let mut ptr: usize = 0; - (0..size).for_each(|_| { - value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..])); - ptr += elem_size - }); - Self { - value, - log_q, - log_base2k, - log_scale: 0, - } - } - - fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem { - assert!(size > 0); - let n: usize = module.n(); - assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); 
- let mut value: Vec = Vec::new(); - let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let elem_size = VecZnx::bytes_of(n, 1, cols); - let mut ptr: usize = 0; - (0..size).for_each(|_| { - value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..])); - ptr += elem_size - }); - Self { - value, - log_q, - log_base2k, - log_scale: 0, - } - } - - fn zero(&mut self) { - self.value.iter_mut().for_each(|i| i.zero()); - } -} - -pub trait ElemCommon { - fn n(&self) -> usize; - fn log_n(&self) -> usize; - fn elem(&self) -> &Elem; - fn elem_mut(&mut self) -> &mut Elem; - fn size(&self) -> usize; - fn layout(&self) -> Layout; - fn rows(&self) -> usize; - fn cols(&self) -> usize; - fn log_base2k(&self) -> usize; - fn log_q(&self) -> usize; - fn log_scale(&self) -> usize; - fn at(&self, i: usize) -> &T; - fn at_mut(&mut self, i: usize) -> &mut T; -} - -impl ElemCommon for Elem { - fn n(&self) -> usize { - self.value[0].n() - } - - fn log_n(&self) -> usize { - self.value[0].log_n() - } - - fn elem(&self) -> &Elem { - self - } - - fn elem_mut(&mut self) -> &mut Elem { - self - } - - fn size(&self) -> usize { - self.value.len() - } - - fn layout(&self) -> Layout { - self.value[0].layout() - } - - fn rows(&self) -> usize { - self.value[0].rows() - } - - fn cols(&self) -> usize { - self.value[0].cols() - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_q(&self) -> usize { - self.log_q - } - - fn log_scale(&self) -> usize { - self.log_scale - } - - fn at(&self, i: usize) -> &T { - assert!(i < self.size()); - &self.value[i] - } - - fn at_mut(&mut self, i: usize) -> &mut T { - assert!(i < self.size()); - &mut self.value[i] - } -} - -impl Elem { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self { - assert!(rows > 0); - let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut value: Vec = Vec::new(); - (0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols))); - Self { - value, - log_q, - log_base2k, - 
log_scale: 0, - } - } -} - -impl Elem { - pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self { - assert!(rows > 0); - assert!(cols > 0); - let mut value: Vec = Vec::new(); - (0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols))); - Self { - value: value, - log_q: 0, - log_base2k: log_base2k, - log_scale: 0, - } - } -} diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs deleted file mode 100644 index 7354a0f..0000000 --- a/rlwe/src/encryptor.rs +++ /dev/null @@ -1,369 +0,0 @@ -use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemCommon, ElemVecZnx}; -use crate::keys::SecretKey; -use crate::parameters::Parameters; -use crate::plaintext::Plaintext; -use base2k::sampling::Sampling; -use base2k::{ - ZnxInfos, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, - MatZnxDftOps, -}; - -use sampling::source::{Source, new_seed}; - -impl Parameters { - pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize { - encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) - } - pub fn encrypt_rlwe_sk( - &self, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - tmp_bytes: &mut [u8], - ) { - encrypt_rlwe_sk( - self.module(), - &mut ct.0, - pt.map(|pt| pt.at(0)), - sk, - source_xa, - source_xe, - self.xe(), - tmp_bytes, - ) - } -} - -pub struct EncryptorSk { - sk: ScalarZnxDft, - source_xa: Source, - source_xe: Source, - initialized: bool, - tmp_bytes: Vec, -} - -impl EncryptorSk { - pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self { - let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); - let mut initialized: bool = false; - if let Some(sk) = sk { - sk.prepare(params.module(), &mut sk_svp_ppol); - initialized = true; - } - Self { - sk: sk_svp_ppol, - initialized, - source_xa: Source::new(new_seed()), - source_xe: 
Source::new(new_seed()), - tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())], - } - } - - pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) { - sk.prepare(module, &mut self.sk); - self.initialized = true; - } - - pub fn seed_source_xa(&mut self, seed: [u8; 32]) { - self.source_xa = Source::new(seed) - } - - pub fn seed_source_xe(&mut self, seed: [u8; 32]) { - self.source_xe = Source::new(seed) - } - - pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext, pt: Option<&Plaintext>) { - assert!( - self.initialized == true, - "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" - ); - params.encrypt_rlwe_sk( - ct, - pt, - &self.sk, - &mut self.source_xa, - &mut self.source_xe, - &mut self.tmp_bytes, - ); - } - - pub fn encrypt_rlwe_sk_core( - &self, - params: &Parameters, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, - source_xa: &mut Source, - source_xe: &mut Source, - tmp_bytes: &mut [u8], - ) { - assert!( - self.initialized == true, - "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" - ); - params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes); - } -} - -pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { - module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes() -} -pub fn encrypt_rlwe_sk( - module: &Module, - ct: &mut Elem, - pt: Option<&VecZnx>, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], -) { - encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes) -} - -fn encrypt_rlwe_sk_core( - module: &Module, - ct: &mut Elem, - pt: Option<&VecZnx>, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], -) { - let cols: usize = ct.cols(); - let 
log_base2k: usize = ct.log_base2k(); - let log_q: usize = ct.log_q(); - - assert!( - tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q), - "invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}", - tmp_bytes.len(), - encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - ); - - let log_q: usize = ct.log_q(); - let log_base2k: usize = ct.log_base2k(); - let c1: &mut VecZnx = ct.at_mut(1); - - // c1 <- Z_{2^prec}[X]/(X^{N}+1) - module.fill_uniform(log_base2k, c1, cols, source_xa); - - let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - - // Scratch space for DFT values - let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft); - - // Applies buf_dft <- DFT(s) * DFT(c1) - module.svp_apply_dft(&mut buf_dft, sk, c1); - - // Alias scratch space - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); - - // buf_big = s x c1 - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - - match PT_POS { - // c0 <- -s x c1 + m - 0 => { - let c0: &mut VecZnx = ct.at_mut(0); - if let Some(pt) = pt { - module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt); - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); - } else { - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); - module.vec_znx_negate_inplace(c0); - } - } - // c1 <- c1 + m - 1 => { - if let Some(pt) = pt { - module.vec_znx_add_inplace(c1, pt); - c1.normalize(log_base2k, tmp_bytes_normalize); - } - let c0: &mut VecZnx = ct.at_mut(0); - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); - module.vec_znx_negate_inplace(c0); - } - _ => panic!("PT_POS must be 1 or 2"), - } - - // c0 <- -s x c1 + m + e - module.add_normal( - log_base2k, - ct.at_mut(0), - log_q, - source_xe, - sigma, - (sigma * 6.0).ceil(), - ); -} - -impl Parameters { - pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { 
- encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) - } -} - -pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - let cols = (log_q + log_base2k - 1) / log_base2k; - Elem::::bytes_of(module, log_base2k, log_q, 2) - + Plaintext::bytes_of(module, log_base2k, log_q) - + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - + module.vmp_prepare_tmp_bytes(rows, cols) -} - -pub fn encrypt_grlwe_sk( - module: &Module, - ct: &mut Ciphertext, - m: &Scalar, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], -) { - let log_q: usize = ct.log_q(); - let log_base2k: usize = ct.log_base2k(); - let (left, right) = ct.0.value.split_at_mut(1); - encrypt_grlwe_sk_core::<0>( - module, - log_base2k, - [&mut left[0], &mut right[0]], - log_q, - m, - sk, - source_xa, - source_xe, - sigma, - tmp_bytes, - ) -} - -impl Parameters { - pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { - encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) - } -} - -pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - let cols = (log_q + log_base2k - 1) / log_base2k; - Elem::::bytes_of(module, log_base2k, log_q, 2) - + Plaintext::bytes_of(module, log_base2k, log_q) - + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - + module.vmp_prepare_tmp_bytes(rows, cols) -} - -pub fn encrypt_rgsw_sk( - module: &Module, - ct: &mut Ciphertext, - m: &Scalar, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], -) { - let log_q: usize = ct.log_q(); - let log_base2k: usize = ct.log_base2k(); - - let (left, right) = ct.0.value.split_at_mut(2); - let (ll, lr) = left.split_at_mut(1); - let (rl, rr) = right.split_at_mut(1); - - encrypt_grlwe_sk_core::<0>( - module, - log_base2k, - [&mut ll[0], &mut lr[0]], - log_q, - m, - sk, 
- source_xa, - source_xe, - sigma, - tmp_bytes, - ); - encrypt_grlwe_sk_core::<1>( - module, - log_base2k, - [&mut rl[0], &mut rr[0]], - log_q, - m, - sk, - source_xa, - source_xe, - sigma, - tmp_bytes, - ); -} - -fn encrypt_grlwe_sk_core( - module: &Module, - log_base2k: usize, - mut ct: [&mut MatZnxDft; 2], - log_q: usize, - m: &Scalar, - sk: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - tmp_bytes: &mut [u8], -) { - let rows: usize = ct[0].rows(); - - let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q); - - assert!( - tmp_bytes.len() >= min_tmp_bytes_len, - "invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}", - tmp_bytes.len(), - min_tmp_bytes_len - ); - - let bytes_of_elem: usize = Elem::::bytes_of(module, log_base2k, log_q, 2); - let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q); - let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q); - - let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt); - let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk); - let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); - - let mut tmp_elem: Elem = Elem::::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem); - let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt); - - (0..rows).for_each(|row_i| { - // Sets the i-th row of the RLWE sample to m (i.e. 
m * 2^{-log_base2k*i}) - tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw()); - - // Encrypts RLWE(m * 2^{-log_base2k*i}) - encrypt_rlwe_sk_core::( - module, - &mut tmp_elem, - Some(tmp_pt.at(0)), - sk, - source_xa, - source_xe, - sigma, - tmp_bytes_enc_sk, - ); - - // Zeroes the ith-row of tmp_pt - tmp_pt.at_mut(0).at_mut(row_i).fill(0); - - // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a] - module.vmp_prepare_row( - ct[0], - tmp_elem.at(0).raw(), - row_i, - tmp_bytes_vmp_prepare_row, - ); - module.vmp_prepare_row( - &mut ct[1], - tmp_elem.at(1).raw(), - row_i, - tmp_bytes_vmp_prepare_row, - ); - }); -} diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs deleted file mode 100644 index 9315cd8..0000000 --- a/rlwe/src/gadget_product.rs +++ /dev/null @@ -1,383 +0,0 @@ -use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps}; -use std::cmp::min; - -pub fn gadget_product_core_tmp_bytes( - module: &Module, - log_base2k: usize, - res_log_q: usize, - in_log_q: usize, - gct_rows: usize, - gct_log_q: usize, -) -> usize { - let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k; - let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k; - let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k; - module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols) -} - -impl Parameters { - pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize { - gadget_product_core_tmp_bytes( - self.module(), - self.log_base2k(), - res_log_q, - in_log_q, - gct_rows, - gct_log_q, - ) - } -} - -pub fn gadget_product_core( - module: &Module, - res_dft_0: &mut VecZnxDft, - res_dft_1: &mut VecZnxDft, - a: &VecZnx, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - assert!(b_cols <= b.cols()); - 
module.vec_znx_dft(res_dft_1, a); - module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); - module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes); -} - -pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize { - return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) - + 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols)); -} - -/// Evaluates the gadget product: c.at(i) = IDFT() -/// -/// # Arguments -/// -/// * `module`: backend support for operations mod (X^N + 1). -/// * `c`: a [Ciphertext] with cols_c cols. -/// * `a`: a [Ciphertext] with cols_a cols. -/// * `b`: a [Ciphertext] with at least min(cols_c, cols_a) rows. -pub fn gadget_product_big( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &Ciphertext, - tmp_bytes: &mut [u8], -) { - let cols: usize = min(c.cols(), a.cols()); - - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft); - - // a1_dft = DFT(a[1]) - module.vec_znx_dft(&mut a1_dft, a.at(1)); - - // c[i] = IDFT(DFT(a[1]) * b[i]) - (0..2).for_each(|i| { - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes); - module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft); - }) -} - -/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT() -/// -/// # Arguments -/// -/// * `module`: backend support for operations mod (X^N + 1). -/// * `c`: a [Ciphertext] with cols_c cols. -/// * `a`: a [Ciphertext] with cols_a cols. -/// * `b`: a [Ciphertext] with at least min(cols_c, cols_a) rows. 
-pub fn gadget_product( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &Ciphertext, - tmp_bytes: &mut [u8], -) { - let cols: usize = min(c.cols(), a.cols()); - - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft); - let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - - // a1_dft = DFT(a[1]) - module.vec_znx_dft(&mut a1_dft, a.at(1)); - - // c[i] = IDFT(DFT(a[1]) * b[i]) - (0..2).for_each(|i| { - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes); - }) -} - -#[cfg(test)] -mod test { - use crate::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - decryptor::decrypt_rlwe, - elem::{Elem, ElemCommon, ElemVecZnx}, - encryptor::encrypt_grlwe_sk, - gadget_product::gadget_product_core, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, - }; - use base2k::{ - BACKEND, ZnxInfos, Sampling, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, - alloc_aligned_u8, - }; - use sampling::source::{Source, new_seed}; - - #[test] - fn test_gadget_product_core() { - let log_base2k: usize = 10; - let q_cols: usize = 7; - let p_cols: usize = 1; - - // Basic parameters with enough limbs to test edge cases - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: 12, - log_q: q_cols * log_base2k, - log_p: p_cols * log_base2k, - log_base2k: log_base2k, - log_scale: 20, - xe: 3.2, - xs: 1 << 11, - }; - - let params: Parameters = 
Parameters::new(¶ms_lit); - - // scratch space - let mut tmp_bytes: Vec = alloc_aligned_u8( - params.decrypt_rlwe_tmp_byte(params.log_qp()) - | params.gadget_product_tmp_bytes( - params.log_qp(), - params.log_qp(), - params.cols_qp(), - params.log_qp(), - ) - | params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp()), - ); - - // Samplers for public and private randomness - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - let mut source_xs: Source = Source::new(new_seed()); - - // Two secret keys - let mut sk0: SecretKey = SecretKey::new(params.module()); - sk0.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk0_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - - let mut sk1: SecretKey = SecretKey::new(params.module()); - sk1.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - - // The gadget ciphertext - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( - params.module(), - log_base2k, - params.cols_qp(), - params.log_qp(), - ); - - // gct = [-b*sk1 + g(sk0) + e, b] - encrypt_grlwe_sk( - params.module(), - &mut gadget_ct, - &sk0.0, - &sk1_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - // Intermediate buffers - - // Input polynopmial, uniformly distributed - let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q()); - params - .module() - .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa); - - // res = g^-1(a) * gct - let mut elem_res: Elem = Elem::::new(params.module(), log_base2k, params.log_qp(), 2); - - // Ideal output = a * s - let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols()); - let mut a_big: VecZnxBig = a_dft.as_vec_znx_big(); - let mut a_times_s: VecZnx = 
params.module().new_vec_znx(1, a.cols()); - - // a * sk0 - params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a); - params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft); - params - .module() - .vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes); - - // Plaintext for decrypted output of gadget product - let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp()); - - // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. - - (1..a.cols() + 1).for_each(|a_cols| { - let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols); - a_trunc.copy_from(&a); - - (1..gadget_ct.cols() + 1).for_each(|b_cols| { - let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols); - let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols); - let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big(); - let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big(); - - pt.elem_mut().zero(); - elem_res.zero(); - - // let b_cols: usize = min(a_cols+1, gadget_ct.cols()); - - println!("a_cols: {} b_cols: {}", a_cols, b_cols); - - // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') - // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) - gadget_product_core( - params.module(), - &mut res_dft_0, - &mut res_dft_1, - &a_trunc, - &gadget_ct, - b_cols, - &mut tmp_bytes, - ); - - // res_big_0 = IDFT(res_dft_0) - params - .module() - .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0); - // res_big_1 = IDFT(res_dft_1); - params - .module() - .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1); - - // res_big_0 = normalize(res_big_0) - params - .module() - .vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes); - - // res_big_1 = normalize(res_big_1) - params - .module() - .vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes); - - // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e - 
decrypt_rlwe( - params.module(), - pt.elem_mut(), - &elem_res, - &sk1_svp_ppol, - &mut tmp_bytes, - ); - - // a * sk0 + e - a*sk0 = e - params - .module() - .vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s); - pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); - - // pt.at(0).print(pt.elem().cols(), 16); - - let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); - - let var_a_err: f64; - - if a_cols < a.cols() { - var_a_err = 1f64 / 12f64; - } else { - var_a_err = 0f64; - } - - let a_logq: usize = a_cols * log_base2k; - let b_logq: usize = b_cols * log_base2k; - let var_msg: f64 = (params.xs() as f64) / params.n() as f64; - - println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq); - - let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq); - - println!("noise_pred: {}", noise_pred); - println!("noise_have: {}", noise_have); - - // assert!(noise_have <= noise_pred + 1.0); - }); - }); - } -} - -impl Parameters { - pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 { - let n: f64 = self.n() as f64; - let var_xs: f64 = self.xs() as f64; - - let var_gct_err_lhs: f64; - let var_gct_err_rhs: f64; - if b_logq < self.log_qp() { - let var_round: f64 = 1f64 / 12f64; - var_gct_err_lhs = var_round; - var_gct_err_rhs = var_round; - } else { - var_gct_err_lhs = self.xe() * self.xe(); - var_gct_err_rhs = 0f64; - } - - noise_grlwe_product( - n, - self.log_base2k(), - var_xs, - var_msg, - var_a_err, - var_gct_err_lhs, - var_gct_err_rhs, - a_logq, - b_logq, - ) - } -} - -pub fn noise_grlwe_product( - n: f64, - log_base2k: usize, - var_xs: f64, - var_msg: f64, - var_a_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let a_logq: usize = min(a_logq, b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; - - let b_scale = 2.0f64.powi(b_logq as i32); - let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - - let base: 
f64 = (1 << (log_base2k)) as f64; - let var_base: f64 = base * base / 12f64; - - // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) - // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs - let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); - noise += var_msg * var_a_err * a_scale * a_scale * n; - noise = noise.sqrt(); - noise /= b_scale; - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] -} diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs deleted file mode 100644 index 88a2331..0000000 --- a/rlwe/src/key_generator.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}; -use crate::keys::{PublicKey, SecretKey, SwitchingKey}; -use crate::parameters::Parameters; -use base2k::{Module, ScalarZnxDft}; -use sampling::source::Source; - -pub struct KeyGenerator {} - -impl KeyGenerator { - pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey { - let mut sk: SecretKey = SecretKey::new(params.module()); - sk.fill_ternary_hw(params.xs(), source); - sk - } - - pub fn gen_public_key_thread_safe( - &self, - params: &Parameters, - sk_ppol: &ScalarZnxDft, - source: &mut Source, - tmp_bytes: &mut [u8], - ) -> PublicKey { - let mut xa_source: Source = source.branch(); - let mut xe_source: Source = source.branch(); - let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp()); - pk.gen_thread_safe( - params.module(), - sk_ppol, - params.xe(), - &mut xa_source, - &mut xe_source, - tmp_bytes, - ); - pk - } -} - -pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) -} - -pub fn gen_switching_key( - module: &Module, - swk: &mut SwitchingKey, - sk_in: &SecretKey, - sk_out: &ScalarZnxDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: 
f64, - tmp_bytes: &mut [u8], -) { - encrypt_grlwe_sk( - module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes, - ); -} diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs deleted file mode 100644 index e73c7f9..0000000 --- a/rlwe/src/key_switching.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::ciphertext::Ciphertext; -use crate::elem::ElemCommon; -use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; -use std::cmp::min; - -pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k; - let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k; - let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k; - return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols) - + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols)) - + module.bytes_of_vec_znx_dft(1, gct_cols); -} - -pub fn key_switch_rlwe( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes); -} - -pub fn key_switch_rlwe_inplace( - module: &Module, - a: &mut Ciphertext, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes); -} - -fn key_switch_rlwe_core( - module: &Module, - c: *mut Ciphertext, - a: *const Ciphertext, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - // SAFETY WARNING: must ensure `c` and `a` are valid for read/write - let c: &mut Ciphertext = unsafe { &mut *c }; - let a: &Ciphertext = unsafe { &*a }; - - let cols: usize = min(min(c.cols(), a.cols()), b.rows()); - - #[cfg(debug_assertions)] - { - assert!(b_cols <= b.cols()); - assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.cols(), a.cols(), b.rows(), b.cols())); - 
assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - - let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft); - let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft); - let mut res_big = res_dft.as_vec_znx_big(); - - module.vec_znx_dft(&mut a1_dft, a.at(1)); - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); - module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes); - - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - - module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes); -} - -pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} - -pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs deleted file mode 100644 index 511f755..0000000 --- a/rlwe/src/keys.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; -use crate::elem::{Elem, ElemCommon}; -use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; -use base2k::{Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, MatZnxDft}; -use sampling::source::Source; - -pub struct SecretKey(pub Scalar); - -impl SecretKey { - pub fn new(module: &Module) -> Self { - SecretKey(Scalar::new(module.n())) - } - - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - self.0.fill_ternary_prob(prob, source); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - 
self.0.fill_ternary_hw(hw, source); - } - - pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { - module.scalar_znx_dft_prepare(sk_ppol, &self.0) - } -} - -pub struct PublicKey(pub Elem); - -impl PublicKey { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey { - PublicKey(Elem::::new(module, log_base2k, log_q, 2)) - } - - pub fn gen_thread_safe( - &mut self, - module: &Module, - sk: &ScalarZnxDft, - xe: f64, - xa_source: &mut Source, - xe_source: &mut Source, - tmp_bytes: &mut [u8], - ) { - encrypt_rlwe_sk( - module, - &mut self.0, - None, - sk, - xa_source, - xe_source, - xe, - tmp_bytes, - ); - } - - pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { - encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - } -} - -pub struct SwitchingKey(pub Ciphertext); - -impl SwitchingKey { - pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey { - SwitchingKey(new_gadget_ciphertext(module, log_base2k, rows, log_q)) - } - - pub fn n(&self) -> usize { - self.0.n() - } - - pub fn rows(&self) -> usize { - self.0.rows() - } - - pub fn cols(&self) -> usize { - self.0.cols() - } - - pub fn log_base2k(&self) -> usize { - self.0.log_base2k() - } -} diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index aecb526..3a7eec6 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -1,13 +1 @@ -pub mod automorphism; -pub mod ciphertext; -pub mod decryptor; -pub mod elem; -pub mod encryptor; -pub mod gadget_product; -pub mod key_generator; -pub mod key_switching; -pub mod keys; -pub mod parameters; -pub mod plaintext; -pub mod rgsw_product; -pub mod trace; +pub mod ciphertext; \ No newline at end of file diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs deleted file mode 100644 index cd3a91d..0000000 --- a/rlwe/src/parameters.rs +++ /dev/null @@ -1,88 +0,0 @@ -use base2k::module::{BACKEND, Module}; - -pub const DEFAULT_SIGMA: f64 = 3.2; - -pub struct 
ParametersLiteral { - pub backend: BACKEND, - pub log_n: usize, - pub log_q: usize, - pub log_p: usize, - pub log_base2k: usize, - pub log_scale: usize, - pub xe: f64, - pub xs: usize, -} - -pub struct Parameters { - log_n: usize, - log_q: usize, - log_p: usize, - log_scale: usize, - log_base2k: usize, - xe: f64, - xs: usize, - module: Module, -} - -impl Parameters { - pub fn new(p: &ParametersLiteral) -> Self { - assert!( - p.log_n + 2 * p.log_base2k <= 53, - "invalid parameters: p.log_n + 2*p.log_base2k > 53" - ); - Self { - log_n: p.log_n, - log_q: p.log_q, - log_p: p.log_p, - log_scale: p.log_scale, - log_base2k: p.log_base2k, - xe: p.xe, - xs: p.xs, - module: Module::new(1 << p.log_n, p.backend), - } - } - - pub fn n(&self) -> usize { - 1 << self.log_n - } - - pub fn log_scale(&self) -> usize { - self.log_scale - } - - pub fn log_q(&self) -> usize { - self.log_q - } - - pub fn log_p(&self) -> usize { - self.log_p - } - - pub fn log_qp(&self) -> usize { - self.log_q + self.log_p - } - - pub fn cols_q(&self) -> usize { - (self.log_q + self.log_base2k - 1) / self.log_base2k - } - - pub fn cols_qp(&self) -> usize { - (self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k - } - - pub fn log_base2k(&self) -> usize { - self.log_base2k - } - - pub fn module(&self) -> &Module { - &self.module - } - - pub fn xe(&self) -> f64 { - self.xe - } - - pub fn xs(&self) -> usize { - self.xs - } -} diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs deleted file mode 100644 index 258756b..0000000 --- a/rlwe/src/plaintext.rs +++ /dev/null @@ -1,109 +0,0 @@ -use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemCommon, ElemVecZnx}; -use crate::parameters::Parameters; -use base2k::{Layout, Module, VecZnx}; - -pub struct Plaintext(pub Elem); - -impl Parameters { - pub fn new_plaintext(&self, log_q: usize) -> Plaintext { - Plaintext::new(self.module(), self.log_base2k(), log_q) - } - - pub fn bytes_of_plaintext(&self, log_q: usize) -> usize -where { - 
Elem::::bytes_of(self.module(), self.log_base2k(), log_q, 1) - } - - pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { - Plaintext(Elem::::from_bytes( - self.module(), - self.log_base2k(), - log_q, - 1, - bytes, - )) - } -} - -impl Plaintext { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { - Self(Elem::::new(module, log_base2k, log_q, 1)) - } -} - -impl Plaintext { - pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { - Elem::::bytes_of(module, log_base2k, log_q, 1) - } - - pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { - Self(Elem::::from_bytes( - module, log_base2k, log_q, 1, bytes, - )) - } - - pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { - Self(Elem::::from_bytes_borrow( - module, log_base2k, log_q, 1, bytes, - )) - } - - pub fn as_ciphertext(&self) -> Ciphertext { - unsafe { Ciphertext::(std::ptr::read(&self.0)) } - } -} - -impl ElemCommon for Plaintext { - fn n(&self) -> usize { - self.0.n() - } - - fn log_n(&self) -> usize { - self.elem().log_n() - } - - fn log_q(&self) -> usize { - self.0.log_q - } - - fn elem(&self) -> &Elem { - &self.0 - } - - fn elem_mut(&mut self) -> &mut Elem { - &mut self.0 - } - - fn size(&self) -> usize { - self.elem().size() - } - - fn layout(&self) -> Layout { - self.elem().layout() - } - - fn rows(&self) -> usize { - self.0.rows() - } - - fn cols(&self) -> usize { - self.0.cols() - } - - fn at(&self, i: usize) -> &VecZnx { - self.0.at(i) - } - - fn at_mut(&mut self, i: usize) -> &mut VecZnx { - self.0.at_mut(i) - } - - fn log_base2k(&self) -> usize { - self.0.log_base2k() - } - - fn log_scale(&self) -> usize { - self.0.log_scale() - } -} diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs deleted file mode 100644 index 1f76166..0000000 --- a/rlwe/src/rgsw_product.rs +++ /dev/null @@ -1,300 +0,0 @@ -use 
crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; -use std::cmp::min; - -impl Parameters { - pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - rgsw_product_tmp_bytes( - self.module(), - self.log_base2k(), - res_logq, - in_logq, - gct_logq, - ) - } -} -pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k; - let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k; - let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k; - return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols) - + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols)) - + 2 * module.bytes_of_vec_znx_dft(1, gct_cols); -} - -pub fn rgsw_product( - module: &Module, - c: &mut Ciphertext, - a: &Ciphertext, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - #[cfg(debug_assertions)] - { - assert!(b_cols <= b.cols()); - assert_eq!(c.size(), 2); - assert_eq!(a.size(), 2); - assert_eq!(b.size(), 4); - assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.cols(), a.cols(), min(b.rows(), a.cols()), b_cols)); - assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols())); - let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - - let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft); - let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft); - let mut c1_dft: VecZnxDft = 
module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft); - - let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big(); - let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big(); - - module.vec_znx_dft(&mut ai_dft, a.at(0)); - module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes); - module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes); - - module.vec_znx_dft(&mut ai_dft, a.at(1)); - module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes); - module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes); - - module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft); - module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft); - - module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes); - module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes); -} - -pub fn rgsw_product_inplace( - module: &Module, - a: &mut Ciphertext, - b: &Ciphertext, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - #[cfg(debug_assertions)] - { - assert!(b_cols <= b.cols()); - assert_eq!(a.size(), 2); - assert_eq!(b.size(), 4); - assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.cols(), a.cols(), min(b.rows(), a.cols()), b_cols)); - assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols())); - let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - - let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft); - let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft); - let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft); - - let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big(); - let mut c1_big: VecZnxBig = 
c1_dft.as_vec_znx_big(); - - module.vec_znx_dft(&mut ai_dft, a.at(0)); - module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes); - module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes); - - module.vec_znx_dft(&mut ai_dft, a.at(1)); - module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes); - module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes); - - module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft); - module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft); - - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes); - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes); -} - -#[cfg(test)] -mod test { - use crate::{ - ciphertext::{Ciphertext, new_rgsw_ciphertext}, - decryptor::decrypt_rlwe, - elem::ElemCommon, - encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk}, - keys::SecretKey, - parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, - plaintext::Plaintext, - rgsw_product::rgsw_product_inplace, - }; - use base2k::{BACKEND, Encoding, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, MatZnxDft, alloc_aligned}; - use sampling::source::{Source, new_seed}; - - #[test] - fn test_rgsw_product() { - let log_base2k: usize = 10; - let log_q: usize = 50; - let log_p: usize = 15; - - // Basic parameters with enough limbs to test edge cases - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: 12, - log_q: log_q, - log_p: log_p, - log_base2k: log_base2k, - log_scale: 20, - xe: 3.2, - xs: 1 << 11, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let module: &Module = params.module(); - let log_q: usize = params.log_q(); - let log_qp: usize = params.log_qp(); - let gct_rows: usize = params.cols_q(); - let gct_cols: usize = params.cols_qp(); - - // scratch space - let mut tmp_bytes: Vec = alloc_aligned( - params.decrypt_rlwe_tmp_byte(log_q) - | params.encrypt_rlwe_sk_tmp_bytes(log_q) - 
| params.rgsw_product_tmp_bytes(log_q, log_q, log_qp) - | params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp), - ); - - // Samplers for public and private randomness - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - let mut source_xs: Source = Source::new(new_seed()); - - let mut sk: SecretKey = SecretKey::new(module); - sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); - module.svp_prepare(&mut sk_svp_ppol, &sk.0); - - let mut ct_rgsw: Ciphertext = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); - - let k: i64 = 3; - - // X^k - let m: Scalar = module.new_scalar(); - let data: &mut [i64] = m.raw_mut(); - data[k as usize] = 1; - - encrypt_rgsw_sk( - module, - &mut ct_rgsw, - &m, - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - DEFAULT_SIGMA, - &mut tmp_bytes, - ); - - let log_k: usize = 2 * log_base2k; - - let mut ct: Ciphertext = params.new_ciphertext(log_q); - let mut pt: Plaintext = params.new_plaintext(log_q); - let mut pt_rotate: Plaintext = params.new_plaintext(log_q); - - pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); - - module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0)); - - encrypt_rlwe_sk( - module, - &mut ct.elem_mut(), - Some(pt.at(0)), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes); - - decrypt_rlwe( - module, - pt.elem_mut(), - ct.elem(), - &sk_svp_ppol, - &mut tmp_bytes, - ); - - module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0)); - - // pt.at(0).print(pt.cols(), 16); - - let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); - - let var_msg: f64 = 1f64 / params.n() as f64; // X^{k} - let var_a0_err: f64 = params.xe() * params.xe(); - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, 
ct.log_q(), ct_rgsw.log_q()); - - println!("noise_pred: {}", noise_pred); - println!("noise_have: {}", noise_have); - - assert!(noise_have <= noise_pred + 1.0); - } -} - -impl Parameters { - pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 { - let n: f64 = self.n() as f64; - let var_xs: f64 = self.xs() as f64; - - let var_gct_err_lhs: f64; - let var_gct_err_rhs: f64; - if b_logq < self.log_qp() { - let var_round: f64 = 1f64 / 12f64; - var_gct_err_lhs = var_round; - var_gct_err_rhs = var_round; - } else { - var_gct_err_lhs = self.xe() * self.xe(); - var_gct_err_rhs = 0f64; - } - - noise_rgsw_product( - n, - self.log_base2k(), - var_xs, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - a_logq, - b_logq, - ) - } -} - -pub fn noise_rgsw_product( - n: f64, - log_base2k: usize, - var_xs: f64, - var_msg: f64, - var_a0_err: f64, - var_a1_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let a_logq: usize = min(a_logq, b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; - - let b_scale = 2.0f64.powi(b_logq as i32); - let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - - let base: f64 = (1 << (log_base2k)) as f64; - let var_base: f64 = base * base / 12f64; - - // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) - // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs - let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); - noise += var_msg * var_a0_err * a_scale * a_scale * n; - noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs; - noise = noise.sqrt(); - noise /= b_scale; - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] -} diff --git a/rlwe/src/test.rs b/rlwe/src/test.rs deleted file mode 100644 index 2a7e9d0..0000000 --- a/rlwe/src/test.rs +++ /dev/null @@ -1,113 +0,0 @@ -use 
base2k::{alloc_aligned, SvpPPol, SvpPPolOps, VecZnx, BACKEND}; -use sampling::source::{Source, new_seed}; -use crate::{ciphertext::Ciphertext, decryptor::decrypt_rlwe, elem::ElemCommon, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULT_SIGMA}, plaintext::Plaintext}; - - - -pub struct Context{ - pub params: Parameters, - pub sk0: SecretKey, - pub sk0_ppol:SvpPPol, - pub sk1: SecretKey, - pub sk1_ppol: SvpPPol, - pub tmp_bytes: Vec, -} - -impl Context{ - pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self{ - - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: log_n, - log_q: log_q, - log_p: log_p, - log_base2k: log_base2k, - log_scale: 20, - xe: DEFAULT_SIGMA, - xs: 1 << (log_n-1), - }; - - let params: Parameters =Parameters::new(¶ms_lit); - let module = params.module(); - - let log_q: usize = params.log_q(); - - let mut source_xs: Source = Source::new(new_seed()); - - let mut sk0: SecretKey = SecretKey::new(module); - sk0.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk0_ppol: base2k::SvpPPol = module.new_svp_ppol(); - module.svp_prepare(&mut sk0_ppol, &sk0.0); - - let mut sk1: SecretKey = SecretKey::new(module); - sk1.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk1_ppol: base2k::SvpPPol = module.new_svp_ppol(); - module.svp_prepare(&mut sk1_ppol, &sk1.0); - - let tmp_bytes: Vec = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q)| params.encrypt_rlwe_sk_tmp_bytes(log_q)); - - Context{ - params: params, - sk0: sk0, - sk0_ppol: sk0_ppol, - sk1: sk1, - sk1_ppol: sk1_ppol, - tmp_bytes: tmp_bytes, - - } - } - - pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext){ - - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - - encrypt_rlwe_sk( - self.params.module(), - ct.elem_mut(), - Some(pt.elem()), - &self.sk0_ppol, - &mut source_xa, - &mut source_xe, - 
self.params.xe(), - &mut self.tmp_bytes, - ); - } - - pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext, pt: &Plaintext){ - - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - - encrypt_rlwe_sk( - self.params.module(), - ct.elem_mut(), - Some(pt.elem()), - &self.sk1_ppol, - &mut source_xa, - &mut source_xe, - self.params.xe(), - &mut self.tmp_bytes, - ); - } - - pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext){ - decrypt_rlwe( - self.params.module(), - pt.elem_mut(), - ct.elem(), - &self.sk0_ppol, - &mut self.tmp_bytes, - ); - } - - pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext){ - decrypt_rlwe( - self.params.module(), - pt.elem_mut(), - ct.elem(), - &self.sk1_ppol, - &mut self.tmp_bytes, - ); - } -} \ No newline at end of file diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs deleted file mode 100644 index 005c497..0000000 --- a/rlwe/src/trace.rs +++ /dev/null @@ -1,236 +0,0 @@ -use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDftOps, assert_alignement}; -use std::collections::HashMap; - -pub fn trace_galois_elements(module: &Module) -> Vec { - let mut gal_els: Vec = Vec::new(); - (0..module.log_n()).for_each(|i| { - if i == 0 { - gal_els.push(-1); - } else { - gal_els.push(module.galois_element(1 << (i - 1))); - } - }); - gal_els -} - -impl Parameters { - pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq) - } -} - -pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize { - return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) - + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols)); -} - -pub fn trace_inplace( - module: &Module, - 
a: &mut Ciphertext, - start: usize, - end: usize, - b: &HashMap, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - let cols: usize = a.cols(); - - let b_rows: usize; - - if let Some((_, key)) = b.iter().next() { - b_rows = key.value.rows(); - #[cfg(debug_assertions)] - { - println!("{} {}", b_cols, key.value.cols()); - assert!(b_cols <= key.value.cols()) - } - } else { - panic!("b: HashMap, is empty") - } - - #[cfg(debug_assertions)] - { - assert!(start <= end); - assert!(end <= module.n()); - assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols)); - assert_alignement(tmp_bytes.as_ptr()); - } - - let cols: usize = std::cmp::min(b_cols, a.cols()); - - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft); - let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - - let log_base2k: usize = a.log_base2k(); - - (start..end).for_each(|i| { - a.at_mut(0).rsh(log_base2k, 1, tmp_bytes); - a.at_mut(1).rsh(log_base2k, 1, tmp_bytes); - - let p: i64; - if i == 0 { - p = -1; - } else { - p = module.galois_element(1 << (i - 1)); - } - - if let Some(key) = b.get(&p) { - module.vec_znx_dft(&mut a1_dft, a.at(1)); - - // a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT())) - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); - module.vec_znx_big_automorphism_inplace(p, &mut res_big); - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes); - - // a[1] = NORMALIZE(a[1] + AUTO(IDFT())) - 
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - module.vec_znx_big_automorphism_inplace(p, &mut res_big); - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1)); - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes); - } else { - panic!("b[{}] is empty", p) - } - }) -} - -#[cfg(test)] -mod test { - use super::{trace_galois_elements, trace_inplace}; - use crate::{ - automorphism::AutomorphismKey, - ciphertext::Ciphertext, - decryptor::decrypt_rlwe, - elem::ElemCommon, - encryptor::encrypt_rlwe_sk, - keys::SecretKey, - parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, - plaintext::Plaintext, - }; - use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, alloc_aligned}; - use sampling::source::{Source, new_seed}; - use std::collections::HashMap; - - #[test] - fn test_trace_inplace() { - let log_base2k: usize = 10; - let log_q: usize = 50; - let log_p: usize = 15; - - // Basic parameters with enough limbs to test edge cases - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: 12, - log_q: log_q, - log_p: log_p, - log_base2k: log_base2k, - log_scale: 20, - xe: 3.2, - xs: 1 << 11, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let module: &Module = params.module(); - let log_q: usize = params.log_q(); - let log_qp: usize = params.log_qp(); - let gct_rows: usize = params.cols_q(); - let gct_cols: usize = params.cols_qp(); - - // scratch space - let mut tmp_bytes: Vec = alloc_aligned( - params.decrypt_rlwe_tmp_byte(log_q) - | params.encrypt_rlwe_sk_tmp_bytes(log_q) - | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp) - | params.automorphism_tmp_bytes(log_q, log_q, log_qp), - ); - - // Samplers for public and private randomness - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - let mut 
source_xs: Source = Source::new(new_seed()); - - let mut sk: SecretKey = SecretKey::new(module); - sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); - module.svp_prepare(&mut sk_svp_ppol, &sk.0); - - let gal_els: Vec = trace_galois_elements(module); - - let auto_keys: HashMap = AutomorphismKey::new_many( - module, - &gal_els, - &sk, - log_base2k, - gct_rows, - log_qp, - &mut source_xa, - &mut source_xe, - DEFAULT_SIGMA, - &mut tmp_bytes, - ); - - let mut data: Vec = vec![0i64; params.n()]; - - data.iter_mut() - .enumerate() - .for_each(|(i, x)| *x = 1 + i as i64); - - let log_k: usize = 2 * log_base2k; - - let mut ct: Ciphertext = params.new_ciphertext(log_q); - let mut pt: Plaintext = params.new_plaintext(log_q); - - pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); - pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); - - pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); - - pt.at(0).print(0, pt.cols(), 16); - - encrypt_rlwe_sk( - module, - &mut ct.elem_mut(), - Some(pt.at(0)), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes); - trace_inplace( - module, - &mut ct, - 4, - module.log_n(), - &auto_keys, - gct_cols, - &mut tmp_bytes, - ); - - // pt = dec(auto(ct)) - auto(pt) - decrypt_rlwe( - module, - pt.elem_mut(), - ct.elem(), - &sk_svp_ppol, - &mut tmp_bytes, - ); - - pt.at(0).print(0, pt.cols(), 16); - - pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); - - println!("trace: {:?}", &data[..16]); - } -} From 4efe22e723914635ef74c2b6bb01ab522a357207 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 14:38:22 +0200 Subject: [PATCH 30/87] Start of full rewrite of rlwe crate --- rlwe/Cargo.toml | 4 -- rlwe/src/ciphertext.rs | 4 -- rlwe/src/elem.rs | 85 ++++++++++++++++++++++++++++++++++++++++++ rlwe/src/encryption.rs | 71 
+++++++++++++++++++++++++++++++++++ rlwe/src/keys.rs | 64 +++++++++++++++++++++++++++++++ rlwe/src/lib.rs | 4 +- 6 files changed, 223 insertions(+), 9 deletions(-) delete mode 100644 rlwe/src/ciphertext.rs create mode 100644 rlwe/src/elem.rs create mode 100644 rlwe/src/encryption.rs create mode 100644 rlwe/src/keys.rs diff --git a/rlwe/Cargo.toml b/rlwe/Cargo.toml index 0822281..692c4fb 100644 --- a/rlwe/Cargo.toml +++ b/rlwe/Cargo.toml @@ -10,7 +10,3 @@ base2k = {path="../base2k"} sampling = {path="../sampling"} rand_distr = {workspace = true} itertools = {workspace = true} - -[[bench]] -name = "gadget_product" -harness = false \ No newline at end of file diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs deleted file mode 100644 index dc83a66..0000000 --- a/rlwe/src/ciphertext.rs +++ /dev/null @@ -1,4 +0,0 @@ - -pub struct Ciphertext{ - x -} \ No newline at end of file diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs new file mode 100644 index 0000000..7574bf5 --- /dev/null +++ b/rlwe/src/elem.rs @@ -0,0 +1,85 @@ +use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc}; + +pub struct Ciphertext { + data: T, + log_base2k: usize, + log_q: usize, +} + +impl Ciphertext { + pub fn log_base2k(&self) -> usize { + self.log_base2k + } + + pub fn log_q(&self) -> usize { + self.log_q + } + + pub fn data(&self) -> &T { + &self.data + } + + pub fn data_mut(&mut self) -> &mut T { + &mut self.data + } +} + +pub struct Plaintext { + data: T, + log_base2k: usize, + log_q: usize, +} + +impl Plaintext { + pub fn log_base2k(&self) -> usize { + self.log_base2k + } + + pub fn log_q(&self) -> usize { + self.log_q + } + + pub fn data(&self) -> &T { + &self.data + } + + pub fn data_mut(&mut self) -> &mut T { + &mut self.data + } +} + +pub(crate) type CipherVecZnx = Ciphertext>; + +impl Ciphertext>> { + pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx(cols, derive_size(log_base2k, 
log_q)), + log_base2k: log_base2k, + log_q: log_q, + } + } +} + +impl Plaintext>> { + pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_q)), + log_base2k: log_base2k, + log_q: log_q, + } + } +} + +impl Ciphertext, B>> { + pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)), + log_base2k: log_base2k, + log_q: log_q, + } + } +} + +pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize { + (log_q + log_base2k - 1) / log_base2k +} diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs new file mode 100644 index 0000000..ca4b837 --- /dev/null +++ b/rlwe/src/encryption.rs @@ -0,0 +1,71 @@ +use base2k::{ + AddNormal, Backend, FillUniform, Module, VecZnxDftOps, ScalarZnxDftOps, ScalarZnxDftToRef, VecZnxBigOps, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, FFT64 +}; + +use sampling::source::Source; + +use crate::{ + elem::{CipherVecZnx, Plaintext}, + keys::SecretKey, +}; + +pub trait EncryptSk { + fn encrypt( + module: &Module, + res: &mut D, + pt: Option<&Plaintext

>, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, + ) where + P: VecZnxToRef, + S: ScalarZnxDftToRef; +} + +impl EncryptSk> for CipherVecZnx +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + fn encrypt( + module: &Module, + ct: &mut CipherVecZnx, + pt: Option<&Plaintext

>, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, + ) where + P: VecZnxToRef, + S: ScalarZnxDftToRef, + { + let log_base2k: usize = ct.log_base2k(); + let log_q: usize = ct.log_q(); + let mut ct_mut: VecZnx<&mut [u8]> = ct.data_mut().to_mut(); + let size: usize = ct_mut.size(); + + ct_mut.fill_uniform(log_base2k, 1, size, source_xa); + + // c1_dft = DFT(a) * DFT(s) + let (mut c1_dft, scratch_1) = scratch.tmp_vec_znx_dft(module, 1, size); + module.svp_apply(&mut c1_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + + // c1_big = IDFT(c1_dft) + let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size); + module.vec_znx_idft_tmp_a(&mut c1_big, 0, &mut c1_dft, 0); + + // c1_big = m - c1_big + if let Some(pt) = pt { + module.vec_znx_big_sub_small_b_inplace(&mut c1_big, 0, &pt.data().to_ref(), 0); + } + // c1_big += e + c1_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + + // c0 = norm(c1_big) + module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c1_big, 0, scratch_2); + } +} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs new file mode 100644 index 0000000..50f1221 --- /dev/null +++ b/rlwe/src/keys.rs @@ -0,0 +1,64 @@ +use base2k::{ + Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnx, VecZnxDft, + VecZnxDftAlloc, VecZnxDftToMut, +}; +use sampling::source::Source; + +use crate::elem::derive_size; + +pub struct SecretKey { + data: T, +} + +impl SecretKey { + pub fn data(&self) -> &T { + &self.data + } + + pub fn data_mut(&self) -> &mut T { + &mut self.data + } +} + +impl SecretKey>> { + pub fn new(module: &Module) -> Self { + Self { + data: module.new_scalar(1), + } + } + + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_ternary_prob(0, prob, source); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_ternary_hw(0, hw, 
source); + } + + pub fn svp_prepare(&self, module: &Module, sk_prep: &mut SecretKey>) + where + ScalarZnxDft: ScalarZnxDftToMut, + { + module.svp_prepare(&mut sk_prep.data, 0, &self.data, 0) + } +} + +pub struct PublicKey { + data: VecZnxDft, +} + +impl PublicKey, B> { + pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { + Self { + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_q)), + } + } +} + +impl> PublicKey { + pub fn generate(&mut self, module: &Module, sk: &SecretKey>) + where + ScalarZnxDft: ScalarZnxDftToMut, + { + } +} diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 3a7eec6..023acb5 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -1 +1,3 @@ -pub mod ciphertext; \ No newline at end of file +pub mod elem; +pub mod encryption; +pub mod keys; From 645f1a94acb25a982f4301041f1fb4b9f1645e58 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 14:40:57 +0200 Subject: [PATCH 31/87] scope shuffling for encryption of rlwe with sk --- rlwe/src/encryption.rs | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index ca4b837..f8b11c9 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -48,24 +48,29 @@ where let mut ct_mut: VecZnx<&mut [u8]> = ct.data_mut().to_mut(); let size: usize = ct_mut.size(); + // c1 = a ct_mut.fill_uniform(log_base2k, 1, size, source_xa); - // c1_dft = DFT(a) * DFT(s) - let (mut c1_dft, scratch_1) = scratch.tmp_vec_znx_dft(module, 1, size); - module.svp_apply(&mut c1_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - // c1_big = IDFT(c1_dft) - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size); - module.vec_znx_idft_tmp_a(&mut c1_big, 0, &mut c1_dft, 0); + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c1_big = m - c1_big - if let Some(pt) = pt { - 
module.vec_znx_big_sub_small_b_inplace(&mut c1_big, 0, &pt.data().to_ref(), 0); + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); } - // c1_big += e - c1_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); - // c0 = norm(c1_big) - module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c1_big, 0, scratch_2); + // c0_big = m - c0_big + if let Some(pt) = pt { + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, &pt.data().to_ref(), 0); + } + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); } } From d2303aa29e75952e826cc88725ebcfa72744c61e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 14:46:11 +0200 Subject: [PATCH 32/87] small fix to generalize VecZnxBigAlloc --- base2k/src/vec_znx_big_ops.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 185a20c..169c66a 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -149,12 +149,12 @@ pub trait VecZnxBigScratch { fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; } -impl VecZnxBigAlloc for Module { - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { +impl VecZnxBigAlloc for Module { + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { VecZnxBig::new(self, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { VecZnxBig::new_from_bytes(self, cols, size, bytes) } From 669450c4f1899bb0fcb91115a61aaf0f866928e2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 
2025 14:46:26 +0200 Subject: [PATCH 33/87] added encrypt_tmp_bytes --- rlwe/src/encryption.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index f8b11c9..3b291f9 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,5 +1,6 @@ use base2k::{ - AddNormal, Backend, FillUniform, Module, VecZnxDftOps, ScalarZnxDftOps, ScalarZnxDftToRef, VecZnxBigOps, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, FFT64 + AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef, ZnxInfos, }; use sampling::source::Source; @@ -23,6 +24,10 @@ pub trait EncryptSk { ) where P: VecZnxToRef, S: ScalarZnxDftToRef; + + fn encrypt_tmp_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } } impl EncryptSk> for CipherVecZnx From f9b194cca14fba68ee68a4009a8c3274a519320b Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 16:02:32 +0200 Subject: [PATCH 34/87] Updated svp --- base2k/examples/rlwe_encrypt.rs | 6 +++--- base2k/spqlios-arithmetic | 2 +- base2k/src/ffi/svp.rs | 11 ++++++++++ base2k/src/scalar_znx_dft_ops.rs | 37 +++++++++++++++++++++++++------- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 79270ea..1a9f4b3 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -40,14 +40,14 @@ fn main() { let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); + module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + // Applies DFT(ct[1]) * DFT(s) - module.svp_apply( + module.svp_apply_dft_inplace( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) 0, // Selects 
the first column of s_dft - &ct, - 1, // Selects the second column of ct ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 8135d85..b6fa494 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf +Subproject commit b6fa494a14c52842712f8ff032ea80812467dec2 diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs index 71c871d..9d4999f 100644 --- a/base2k/src/ffi/svp.rs +++ b/base2k/src/ffi/svp.rs @@ -33,3 +33,14 @@ unsafe extern "C" { a_sl: u64, ); } + +unsafe extern "C" { + pub unsafe fn svp_apply_dft_to_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + ppol: *const SVP_PPOL, + a: *const VEC_ZNX_DFT, + a_size: u64, + ); +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index ea98a57..a4b3ccc 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -3,14 +3,13 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx, - VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize, + VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize, }; pub trait ScalarZnxDftAlloc { fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; - // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } pub trait ScalarZnxDftOps { @@ -22,7 +21,11 @@ pub trait ScalarZnxDftOps { where R: VecZnxDftToMut, A: ScalarZnxDftToRef, - B: VecZnxToRef; + B: VecZnxDftToRef; + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where 
+ R: VecZnxDftToMut, + A: ScalarZnxDftToRef; } impl ScalarZnxDftAlloc for Module { @@ -58,20 +61,38 @@ impl ScalarZnxDftOps for Module { where R: VecZnxDftToMut, A: ScalarZnxDftToRef, - B: VecZnxToRef, + B: VecZnxDftToRef, { let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); + let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); unsafe { - svp::svp_apply_dft( + svp::svp_apply_dft_to_dft( self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - b.at_ptr(b_col, 0), + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, b.size() as u64, - b.sl() as u64, + ) + } + } + + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, ) } } From e35924f44cc77a987f0b142720281eecb7481e43 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 16:43:06 +0200 Subject: [PATCH 35/87] small fix to base2k rlwe encryption example --- base2k/examples/rlwe_encrypt.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 1a9f4b3..4d2961c 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -43,7 +43,7 @@ fn main() { module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); // Applies DFT(ct[1]) * DFT(s) - module.svp_apply_dft_inplace( + module.svp_apply_inplace( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) @@ -102,13 +102,12 @@ fn main() { // Decryption // 
DFT(ct[1] * s) - module.svp_apply( + module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + module.svp_apply_inplace( &mut buf_dft, 0, // Selects the first column of res. &s_dft, 0, - &ct, - 1, // Selects the second column of ct (ct[1]) ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) From fe6f99b9ce29b9e6d2d8ae1b4bb24553744a692d Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 16:43:17 +0200 Subject: [PATCH 36/87] added rlwe basic sk encryption --- rlwe/src/elem.rs | 19 ++++++++- rlwe/src/encryption.rs | 97 ++++++++++++++++++++++++++++++++++++------ rlwe/src/keys.rs | 2 +- 3 files changed, 102 insertions(+), 16 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 7574bf5..5749208 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,4 @@ -use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc}; +use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc}; pub struct Ciphertext { data: T, @@ -48,7 +48,12 @@ impl Plaintext { } } -pub(crate) type CipherVecZnx = Ciphertext>; +pub(crate) type CtVecZnx = Ciphertext>; +pub(crate) type CtVecZnxDft = Ciphertext>; +pub(crate) type CtMatZnxDft = Ciphertext>; +pub(crate) type PtVecZnx = Plaintext>; +pub(crate) type PtVecZnxDft = Plaintext>; +pub(crate) type PtMatZnxDft = Plaintext>; impl Ciphertext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { @@ -70,6 +75,16 @@ impl Plaintext>> { } } +impl Ciphertext, B>> { + pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)), + log_base2k: log_base2k, + log_q: log_q, + } + } +} + impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { Self { diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index 3b291f9..3d62bfe 100644 --- 
a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,20 +1,21 @@ use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxBigAlloc, - VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToMut, VecZnxToRef, ZnxInfos, }; use sampling::source::Source; use crate::{ - elem::{CipherVecZnx, Plaintext}, + elem::{CtVecZnx, CtVecZnxDft, PtVecZnx}, keys::SecretKey, }; -pub trait EncryptSk { - fn encrypt( +pub trait EncryptSk { + fn encrypt( module: &Module, res: &mut D, - pt: Option<&Plaintext

>, + pt: Option<&P>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -22,7 +23,6 @@ pub trait EncryptSk { sigma: f64, bound: f64, ) where - P: VecZnxToRef, S: ScalarZnxDftToRef; fn encrypt_tmp_bytes(module: &Module, size: usize) -> usize { @@ -30,14 +30,15 @@ pub trait EncryptSk { } } -impl EncryptSk> for CipherVecZnx +impl EncryptSk, PtVecZnx

> for CtVecZnx where VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, { - fn encrypt( + fn encrypt( module: &Module, - ct: &mut CipherVecZnx, - pt: Option<&Plaintext

>, + ct: &mut CtVecZnx, + pt: Option<&PtVecZnx

>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -45,7 +46,6 @@ where sigma: f64, bound: f64, ) where - P: VecZnxToRef, S: ScalarZnxDftToRef, { let log_base2k: usize = ct.log_base2k(); @@ -60,9 +60,10 @@ where { let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1); // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut c0_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); @@ -79,3 +80,73 @@ where module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); } } + +pub trait EncryptZeroSk { + fn encrypt_zero( + module: &Module, + res: &mut D, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, + ) where + S: ScalarZnxDftToRef; + + fn encrypt_zero_tmp_bytes(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } +} + +impl EncryptZeroSk> for CtVecZnxDft +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + fn encrypt_zero( + module: &Module, + ct: &mut CtVecZnxDft, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, + ) where + S: ScalarZnxDftToRef, + { + let log_base2k: usize = ct.log_base2k(); + let log_q: usize = ct.log_q(); + let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.data_mut().to_mut(); + let size: usize = ct_mut.size(); + + // ct[1] = DFT(a) + { + let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); + tmp_znx.fill_uniform(log_base2k, 1, size, source_xa); + module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0); + } + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, 
size); + + { + let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); + } + + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + e) + let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); + module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); + // ct[0] = DFT(-as + e) + module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0); + } +} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 50f1221..d84abc0 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -15,7 +15,7 @@ impl SecretKey { &self.data } - pub fn data_mut(&self) -> &mut T { + pub fn data_mut(&mut self) -> &mut T { &mut self.data } } From 9afe9372bd843458544ef49031880023a0739581 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 18:02:00 +0200 Subject: [PATCH 37/87] wip, playing with base2k traits in rlwe crate to ensure inherent compatibility --- rlwe/src/elem.rs | 171 +++++++++++++++++++++++++++++++++++++---- rlwe/src/encryption.rs | 31 ++++---- rlwe/src/keys.rs | 4 +- 3 files changed, 172 insertions(+), 34 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 5749208..3cb1360 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,52 @@ -use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc}; +use base2k::{ + Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, + VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, +}; + +pub trait Infos +where + T: ZnxInfos, +{ + fn inner(&self) -> &T; + + /// Returns the ring degree of the polynomials. 
+ fn n(&self) -> usize { + self.inner().n() + } + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize { + self.inner().log_n() + } + + /// Returns the number of rows. + fn rows(&self) -> usize { + self.inner().rows() + } + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize { + self.inner().cols() + } + + /// Returns the number of size per polynomial. + fn size(&self) -> usize { + let size: usize = self.inner().size(); + debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q())); + size + } + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } + + /// Returns the base 2 logarithm of the ciphertext base. + fn log_base2k(&self) -> usize; + + /// Returns the base 2 logarithm of the ciphertext modulus. + fn log_q(&self) -> usize; +} pub struct Ciphertext { data: T, @@ -6,20 +54,32 @@ pub struct Ciphertext { log_q: usize, } -impl Ciphertext { - pub fn log_base2k(&self) -> usize { - self.log_base2k - } - - pub fn log_q(&self) -> usize { - self.log_q - } - - pub fn data(&self) -> &T { +impl Infos for Ciphertext +where + T: ZnxInfos, +{ + fn inner(&self) -> &T { &self.data } - pub fn data_mut(&mut self) -> &mut T { + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_q(&self) -> usize { + self.log_q + } +} + +impl DataView for Ciphertext { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for Ciphertext { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -30,15 +90,24 @@ pub struct Plaintext { log_q: usize, } -impl Plaintext { - pub fn log_base2k(&self) -> usize { +impl Infos for Plaintext +where + T: ZnxInfos, +{ + fn inner(&self) -> &T { + &self.data + } + + fn log_base2k(&self) -> usize { self.log_base2k } - pub fn log_q(&self) -> usize { + fn log_q(&self) -> usize { self.log_q } +} +impl Plaintext { pub fn data(&self) -> &T { &self.data 
} @@ -55,6 +124,24 @@ pub(crate) type PtVecZnx = Plaintext>; pub(crate) type PtVecZnxDft = Plaintext>; pub(crate) type PtMatZnxDft = Plaintext>; +impl VecZnxToMut for Ciphertext +where + D: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data_mut().to_mut() + } +} + +impl VecZnxToRef for Ciphertext +where + D: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data().to_ref() + } +} + impl Ciphertext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { @@ -65,6 +152,24 @@ impl Ciphertext>> { } } +impl VecZnxToMut for Plaintext +where + D: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data_mut().to_mut() + } +} + +impl VecZnxToRef for Plaintext +where + D: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data().to_ref() + } +} + impl Plaintext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { Self { @@ -75,6 +180,24 @@ impl Plaintext>> { } } +impl VecZnxDftToMut for Ciphertext +where + D: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data_mut().to_mut() + } +} + +impl VecZnxDftToRef for Ciphertext +where + D: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data().to_ref() + } +} + impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { @@ -85,6 +208,24 @@ impl Ciphertext, B>> { } } +impl MatZnxDftToMut for Ciphertext +where + D: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data_mut().to_mut() + } +} + +impl MatZnxDftToRef for Ciphertext +where + D: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data().to_ref() + } +} + impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { Self { diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index 3d62bfe..de3146f 100644 --- 
a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,15 +1,12 @@ use base2k::{ AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxToMut, VecZnxToRef, ZnxInfos, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut, + VecZnxToRef, ZnxInfos, }; use sampling::source::Source; -use crate::{ - elem::{CtVecZnx, CtVecZnxDft, PtVecZnx}, - keys::SecretKey, -}; +use crate::{elem::Infos, keys::SecretKey}; pub trait EncryptSk { fn encrypt( @@ -30,15 +27,15 @@ pub trait EncryptSk { } } -impl EncryptSk, PtVecZnx

> for CtVecZnx +impl EncryptSk for C where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, + C: VecZnxToMut + ZnxInfos + Infos, + P: VecZnxToRef, { fn encrypt( module: &Module, - ct: &mut CtVecZnx, - pt: Option<&PtVecZnx

>, + ct: &mut C, + pt: Option<&P>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -50,7 +47,7 @@ where { let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnx<&mut [u8]> = ct.data_mut().to_mut(); + let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut(); let size: usize = ct_mut.size(); // c1 = a @@ -71,7 +68,7 @@ where // c0_big = m - c0_big if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, &pt.data().to_ref(), 0); + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); } // c0_big += e c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); @@ -102,13 +99,13 @@ pub trait EncryptZeroSk { } } -impl EncryptZeroSk> for CtVecZnxDft +impl EncryptZeroSk for C where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + C: VecZnxDftToMut + ZnxInfos + Infos, { fn encrypt_zero( module: &Module, - ct: &mut CtVecZnxDft, + ct: &mut C, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -120,7 +117,7 @@ where { let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.data_mut().to_mut(); + let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut(); let size: usize = ct_mut.size(); // ct[1] = DFT(a) diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index d84abc0..77f1d9a 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,5 +1,5 @@ use base2k::{ - Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnx, VecZnxDft, + Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, }; use sampling::source::Source; @@ -56,7 +56,7 @@ impl PublicKey, B> { } impl> PublicKey { - pub fn generate(&mut self, module: &Module, sk: &SecretKey>) + pub fn generate(&mut self, module: &Module, sk: &SecretKey>, scratch: &mut Scratch) where ScalarZnxDft: ScalarZnxDftToMut, { From 
ccebb80660e6b71c2dc3d4319bd5b362e3a6816e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 10:23:18 +0200 Subject: [PATCH 38/87] wip --- base2k/src/sampling.rs | 2 - base2k/src/scalar_znx_dft_ops.rs | 4 +- rlwe/src/elem.rs | 35 +++++---- rlwe/src/encryption.rs | 118 ++++++++++++++++++++++++++----- rlwe/src/keys.rs | 13 +++- 5 files changed, 128 insertions(+), 44 deletions(-) diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 212658a..b2d6f22 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -297,8 +297,6 @@ where #[cfg(test)] mod tests { - use std::fmt::Display; - use super::{AddNormal, FillUniform}; use crate::vec_znx_ops::*; use crate::znx_base::*; diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index a4b3ccc..a51d72f 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx, - VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize, + Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, + VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, }; pub trait ScalarZnxDftAlloc { diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 3cb1360..1126ed4 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,13 +1,12 @@ use base2k::{ - Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, - VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, + Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx, + VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, 
VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, }; -pub trait Infos -where - T: ZnxInfos, -{ - fn inner(&self) -> &T; +pub trait Infos { + type Inner: ZnxInfos; + + fn inner(&self) -> &Self::Inner; /// Returns the ring degree of the polynomials. fn n(&self) -> usize { @@ -48,17 +47,16 @@ where fn log_q(&self) -> usize; } -pub struct Ciphertext { - data: T, +pub struct RLWECt{ + data: VecZnx, log_base2k: usize, log_q: usize, } -impl Infos for Ciphertext -where - T: ZnxInfos, -{ - fn inner(&self) -> &T { +impl Infos for RLWECt { + type Inner = T; + + fn inner(&self) -> &Self::Inner { &self.data } @@ -90,11 +88,10 @@ pub struct Plaintext { log_q: usize, } -impl Infos for Plaintext -where - T: ZnxInfos, -{ - fn inner(&self) -> &T { +impl Infos for Plaintext { + type Inner = T; + + fn inner(&self) -> &Self::Inner { &self.data } diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index de3146f..e0f9e1f 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -6,13 +6,16 @@ use base2k::{ use sampling::source::Source; -use crate::{elem::Infos, keys::SecretKey}; +use crate::{ + elem::{Ciphertext, Infos, Plaintext}, + keys::SecretKey, +}; -pub trait EncryptSk { +pub trait EncryptSk { fn encrypt( module: &Module, - res: &mut D, - pt: Option<&P>, + res: &mut Ciphertext, + pt: Option<&Plaintext

>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -22,20 +25,18 @@ pub trait EncryptSk { ) where S: ScalarZnxDftToRef; - fn encrypt_tmp_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } + fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize; } -impl EncryptSk for C +impl EncryptSk for Ciphertext where - C: VecZnxToMut + ZnxInfos + Infos, - P: VecZnxToRef, + C: VecZnxToMut + ZnxInfos, + P: VecZnxToRef + ZnxInfos, { fn encrypt( module: &Module, - ct: &mut C, - pt: Option<&P>, + ct: &mut Ciphertext, + pt: Option<&Plaintext

>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -76,6 +77,41 @@ where // c0 = norm(c0_big = -as + m + e) module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); } + + fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl Ciphertext +where + C: VecZnxToMut + ZnxInfos, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: Option<&Plaintext

>, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, + ) where + P: VecZnxToRef + ZnxInfos, + S: ScalarZnxDftToRef, + { + >::encrypt( + module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound, + ); + } + + pub fn encrypt_sk_scratch_bytes

(module: &Module, size: usize) -> usize + where + Self: EncryptSk, + { + >::encrypt_scratch_bytes(module, size) + } } pub trait EncryptZeroSk { @@ -91,17 +127,12 @@ pub trait EncryptZeroSk { ) where S: ScalarZnxDftToRef; - fn encrypt_zero_tmp_bytes(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() - } + fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize; } impl EncryptZeroSk for C where - C: VecZnxDftToMut + ZnxInfos + Infos, + C: VecZnxDftToMut + ZnxInfos + Infos, { fn encrypt_zero( module: &Module, @@ -146,4 +177,53 @@ where // ct[0] = DFT(-as + e) module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0); } + + fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize{ + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } +} + +#[cfg(test)] +mod tests { + use base2k::{FFT64, Module, ScratchOwned, VecZnx, Scalar}; + use sampling::source::Source; + + use crate::{elem::{Ciphertext, Infos, Plaintext}, keys::SecretKey}; + + #[test] + fn encrypt_sk_vec_znx_fft64() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_q: usize = 54; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6; + + let mut ct: Ciphertext>> = Ciphertext::>>::new(&module, log_base2k, log_q, 2); + let mut pt: Plaintext>> = Plaintext::>>::new(&module, log_base2k, log_q); + + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + + let mut scratch: ScratchOwned = ScratchOwned::new(ct.encrypt_encsk_scratch_bytes(&module, ct.size())); + + let mut sk: SecretKey>> = SecretKey::new(&module); + let mut sk_prep + sk.svp_prepare(&module, &mut sk_prep); + + ct.encrypt_sk( + 
&module, + Some(&pt), + &sk_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + sigma, + bound, + ); + } } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 77f1d9a..ee8bb94 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,6 +1,5 @@ use base2k::{ - Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftToMut, + Backend, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, ZnxInfos, FFT64 }; use sampling::source::Source; @@ -43,6 +42,16 @@ impl SecretKey>> { } } +type SecretKeyPrep = SecretKey>; + +impl SecretKey, B>> { + pub fn new(module: &Module) -> Self{ + Self{ + data: module.new_scalar_znx_dft(1) + } + } +} + pub struct PublicKey { data: VecZnxDft, } From a6224f756341dcd33faf5aec86f047e62240b1b6 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 11:57:56 +0200 Subject: [PATCH 39/87] updated Scalar name --- base2k/src/scalar_znx.rs | 74 ++++++++++++++++---------------- base2k/src/scalar_znx_dft_ops.rs | 6 +-- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index acdac8c..731add3 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -5,13 +5,13 @@ use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -pub struct Scalar { +pub struct ScalarZnx { data: D, n: usize, cols: usize, } -impl ZnxInfos for Scalar { +impl ZnxInfos for ScalarZnx { fn cols(&self) -> usize { self.cols } @@ -29,30 +29,30 @@ impl ZnxInfos for Scalar { } } -impl ZnxSliceSize for Scalar { +impl ZnxSliceSize for ScalarZnx { fn sl(&self) -> usize { self.n() } } -impl DataView for Scalar { +impl DataView for ScalarZnx { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for Scalar { +impl DataViewMut for 
ScalarZnx { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } -impl> ZnxView for Scalar { +impl> ZnxView for ScalarZnx { type Scalar = i64; } -impl + AsRef<[u8]>> Scalar { +impl + AsRef<[u8]>> ScalarZnx { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; @@ -71,7 +71,7 @@ impl + AsRef<[u8]>> Scalar { } } -impl>> Scalar { +impl>> ScalarZnx { pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } @@ -96,37 +96,37 @@ impl>> Scalar { } } -pub type ScalarOwned = Scalar>; +pub type ScalarZnxOwned = ScalarZnx>; -pub trait ScalarAlloc { +pub trait ScalarZnxAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; - fn new_scalar(&self, cols: usize) -> ScalarOwned; - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; + fn new_scalar(&self, cols: usize) -> ScalarZnxOwned; + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; } -impl ScalarAlloc for Module { +impl ScalarZnxAlloc for Module { fn bytes_of_scalar(&self, cols: usize) -> usize { - ScalarOwned::bytes_of::(self.n(), cols) + ScalarZnxOwned::bytes_of::(self.n(), cols) } - fn new_scalar(&self, cols: usize) -> ScalarOwned { - ScalarOwned::new::(self.n(), cols) + fn new_scalar(&self, cols: usize) -> ScalarZnxOwned { + ScalarZnxOwned::new::(self.n(), cols) } - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { - ScalarOwned::new_from_bytes::(self.n(), cols, bytes) + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { + ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) } } -pub trait ScalarToRef { - fn to_ref(&self) -> Scalar<&[u8]>; +pub trait ScalarZnxToRef { + fn to_ref(&self) -> ScalarZnx<&[u8]>; } -pub trait ScalarToMut { - fn to_mut(&mut self) -> Scalar<&mut [u8]>; +pub trait ScalarZnxToMut { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>; } -impl 
ScalarToMut for Scalar> { - fn to_mut(&mut self) -> Scalar<&mut [u8]> { - Scalar { +impl ScalarZnxToMut for ScalarZnx> { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + ScalarZnx { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, @@ -134,9 +134,9 @@ impl ScalarToMut for Scalar> { } } -impl ScalarToRef for Scalar> { - fn to_ref(&self) -> Scalar<&[u8]> { - Scalar { +impl ScalarZnxToRef for ScalarZnx> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { data: self.data.as_slice(), n: self.n, cols: self.cols, @@ -144,9 +144,9 @@ impl ScalarToRef for Scalar> { } } -impl ScalarToMut for Scalar<&mut [u8]> { - fn to_mut(&mut self) -> Scalar<&mut [u8]> { - Scalar { +impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + ScalarZnx { data: self.data, n: self.n, cols: self.cols, @@ -154,9 +154,9 @@ impl ScalarToMut for Scalar<&mut [u8]> { } } -impl ScalarToRef for Scalar<&mut [u8]> { - fn to_ref(&self) -> Scalar<&[u8]> { - Scalar { +impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { data: self.data, n: self.n, cols: self.cols, @@ -164,9 +164,9 @@ impl ScalarToRef for Scalar<&mut [u8]> { } } -impl ScalarToRef for Scalar<&[u8]> { - fn to_ref(&self) -> Scalar<&[u8]> { - Scalar { +impl ScalarZnxToRef for ScalarZnx<&[u8]> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { data: self.data, n: self.n, cols: self.cols, diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index a51d72f..888b2a9 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -2,7 +2,7 @@ use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, + Backend, FFT64, Module, ScalarZnxToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, 
VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, }; @@ -16,7 +16,7 @@ pub trait ScalarZnxDftOps { fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: ScalarZnxDftToMut, - A: ScalarToRef; + A: ScalarZnxToRef; fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxDftToMut, @@ -46,7 +46,7 @@ impl ScalarZnxDftOps for Module { fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: ScalarZnxDftToMut, - A: ScalarToRef, + A: ScalarZnxToRef, { unsafe { svp::svp_prepare( From 240884db8dea69c47a4dfbb98d6abb6ec36a97fb Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 11:58:09 +0200 Subject: [PATCH 40/87] fixed wrong buffer size zeroing --- base2k/src/znx_base.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index db6a50c..e8dcab2 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -110,7 +110,7 @@ where std::ptr::write_bytes( self.as_mut_ptr(), 0, - self.n() * size_of::() * self.poly_count(), + self.n() * self.poly_count(), ); } } @@ -120,7 +120,7 @@ where std::ptr::write_bytes( self.at_mut_ptr(i, j), 0, - self.n() * size_of::(), + self.n(), ); } } @@ -128,7 +128,6 @@ where // Blanket implementations impl ZnxZero for T where T: ZnxViewMut {} -// impl ZnxRsh for T where T: ZnxZero {} use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; From 6ce525e5a1516d6156b239862dd2a03b8374d237 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 12:05:12 +0200 Subject: [PATCH 41/87] added sk encryption --- base2k/examples/rlwe_encrypt.rs | 4 +- base2k/src/scalar_znx_dft_ops.rs | 4 +- base2k/src/znx_base.rs | 12 +- rlwe/src/elem.rs | 246 ++++++++++++--------------- rlwe/src/encryption.rs | 283 ++++++++++++++++++------------- rlwe/src/keys.rs | 83 ++++++--- 6 files changed, 333 insertions(+), 299 deletions(-) diff --git 
a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 4d2961c..16b7d3a 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use base2k::{ - AddNormal, Encoding, FFT64, FillUniform, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, }; @@ -20,7 +20,7 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar> = module.new_scalar(1); + let mut s: ScalarZnx> = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index 888b2a9..f5f8f7f 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, Module, ScalarZnxToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, - VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, }; pub trait ScalarZnxDftAlloc { diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index e8dcab2..5230dfd 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -107,21 +107,13 @@ where { fn zero(&mut self) { unsafe { - std::ptr::write_bytes( - self.as_mut_ptr(), - 0, - self.n() * self.poly_count(), - ); + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); } } fn zero_at(&mut self, i: usize, j: 
usize) { unsafe { - std::ptr::write_bytes( - self.at_mut_ptr(i, j), - 0, - self.n(), - ); + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); } } } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 1126ed4..fe1b3b4 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,6 +1,6 @@ use base2k::{ - Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx, - VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, + Backend, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, + ZnxInfos, }; pub trait Infos { @@ -31,7 +31,7 @@ pub trait Infos { /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q())); + debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k())); size } @@ -43,18 +43,18 @@ pub trait Infos { /// Returns the base 2 logarithm of the ciphertext base. fn log_base2k(&self) -> usize; - /// Returns the base 2 logarithm of the ciphertext modulus. - fn log_q(&self) -> usize; + /// Returns the bit precision of the ciphertext. 
+ fn log_k(&self) -> usize; } -pub struct RLWECt{ - data: VecZnx, - log_base2k: usize, - log_q: usize, +pub struct RLWECt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, } -impl Infos for RLWECt { - type Inner = T; +impl Infos for RLWECt { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -64,32 +64,37 @@ impl Infos for RLWECt { self.log_base2k } - fn log_q(&self) -> usize { - self.log_q + fn log_k(&self) -> usize { + self.log_k } } -impl DataView for Ciphertext { - type D = D; - fn data(&self) -> &Self::D { - &self.data +impl VecZnxToMut for RLWECt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() } } -impl DataViewMut for Ciphertext { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data +impl VecZnxToRef for RLWECt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() } } -pub struct Plaintext { - data: T, - log_base2k: usize, - log_q: usize, +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, } -impl Infos for Plaintext { - type Inner = T; +impl Infos for RLWEPt { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -99,140 +104,99 @@ impl Infos for Plaintext { self.log_base2k } - fn log_q(&self) -> usize { - self.log_q + fn log_k(&self) -> usize { + self.log_k } } -impl Plaintext { - pub fn data(&self) -> &T { +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWECt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) 
-> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { &self.data } - pub fn data_mut(&mut self) -> &mut T { - &mut self.data + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k } } -pub(crate) type CtVecZnx = Ciphertext>; -pub(crate) type CtVecZnxDft = Ciphertext>; -pub(crate) type CtMatZnxDft = Ciphertext>; -pub(crate) type PtVecZnx = Plaintext>; -pub(crate) type PtVecZnxDft = Plaintext>; -pub(crate) type PtMatZnxDft = Plaintext>; - -impl VecZnxToMut for Ciphertext +impl VecZnxDftToMut for RLWECtDft where - D: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data_mut().to_mut() - } -} - -impl VecZnxToRef for Ciphertext -where - D: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data().to_ref() - } -} - -impl Ciphertext>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -impl VecZnxToMut for Plaintext -where - D: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data_mut().to_mut() - } -} - -impl VecZnxToRef for Plaintext -where - D: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data().to_ref() - } -} - -impl Plaintext>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - 
log_q: log_q, - } - } -} - -impl VecZnxDftToMut for Ciphertext -where - D: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data_mut().to_mut() + self.data.to_mut() } } -impl VecZnxDftToRef for Ciphertext +impl VecZnxDftToRef for RLWECtDft where - D: VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, { fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data().to_ref() + self.data.to_ref() } } -impl Ciphertext, B>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -impl MatZnxDftToMut for Ciphertext -where - D: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data_mut().to_mut() - } -} - -impl MatZnxDftToRef for Ciphertext -where - D: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data().to_ref() - } -} - -impl Ciphertext, B>> { - pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize { - (log_q + log_base2k - 1) / log_base2k +pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { + (log_k + log_base2k - 1) / log_base2k } diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index e0f9e1f..148ded4 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,161 +1,166 @@ +use std::cmp::min; + use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut, - VecZnxToRef, ZnxInfos, + AddNormal, Backend, FFT64, 
FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, + VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToMut, VecZnxToRef, }; use sampling::source::Source; use crate::{ - elem::{Ciphertext, Infos, Plaintext}, - keys::SecretKey, + elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, }; -pub trait EncryptSk { - fn encrypt( - module: &Module, - res: &mut Ciphertext, - pt: Option<&Plaintext

>, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef; - - fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize; +pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } -impl EncryptSk for Ciphertext -where - C: VecZnxToMut + ZnxInfos, - P: VecZnxToRef + ZnxInfos, +pub fn encrypt_rlwe_sk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - fn encrypt( - module: &Module, - ct: &mut Ciphertext, - pt: Option<&Plaintext

>, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef, + let log_base2k: usize = ct.log_base2k(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); + + // c1 = a + ct.data.fill_uniform(log_base2k, 1, size, source_xa); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + { - let log_base2k: usize = ct.log_base2k(); - let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut(); - let size: usize = ct_mut.size(); + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - // c1 = a - ct_mut.fill_uniform(log_base2k, 1, size, source_xa); + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); } - fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + // c0_big = m - c0_big + if let Some(pt) = pt { + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); } + // c0_big += e + 
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); } -impl Ciphertext -where - C: VecZnxToMut + ZnxInfos, +pub fn decrypt_rlwe( + module: &Module, + pt: &mut RLWEPt

, + ct: &RLWECt, + sk: &SecretKeyDft, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { + let size: usize = min(pt.size(), ct.size()); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.log_base2k(); + pt.log_k = min(pt.log_k(), ct.log_k()); +} + +pub fn decrypt_rlwe_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +} + +impl RLWECt { pub fn encrypt_sk( &mut self, module: &Module, - pt: Option<&Plaintext

>, - sk: &SecretKey, + pt: Option<&RLWEPt

>, + sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, sigma: f64, bound: f64, ) where - P: VecZnxToRef + ZnxInfos, - S: ScalarZnxDftToRef, + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - >::encrypt( + encrypt_rlwe_sk( module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound, - ); + ) } - pub fn encrypt_sk_scratch_bytes

(module: &Module, size: usize) -> usize + pub fn decrypt(&self, module: &Module, pt: &mut RLWEPt

, sk: &SecretKeyDft, scratch: &mut Scratch) where - Self: EncryptSk, + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - >::encrypt_scratch_bytes(module, size) + decrypt_rlwe(module, pt, self, sk, scratch); } } -pub trait EncryptZeroSk { - fn encrypt_zero( - module: &Module, - res: &mut D, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef; - - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize; +pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } -impl EncryptZeroSk for C -where - C: VecZnxDftToMut + ZnxInfos + Infos, -{ +impl RLWECtDft { fn encrypt_zero( module: &Module, - ct: &mut C, - sk: &SecretKey, + ct: &mut RLWECtDft, + sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, sigma: f64, bound: f64, ) where - S: ScalarZnxDftToRef, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, { let log_base2k: usize = ct.log_base2k(); - let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut(); - let size: usize = ct_mut.size(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); // ct[1] = DFT(a) { let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); tmp_znx.fill_uniform(log_base2k, 1, size, source_xa); - module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); } let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); @@ -163,22 +168,22 @@ where { let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); } // 
c0_big += e - c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); // c0 = norm(c0_big = -as + e) let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); // ct[0] = DFT(-as + e) - module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0); + module.vec_znx_dft(ct, 0, &tmp_znx, 0); } - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize{ + fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize { (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + module.bytes_of_vec_znx(1, size) @@ -188,42 +193,80 @@ where #[cfg(test)] mod tests { - use base2k::{FFT64, Module, ScratchOwned, VecZnx, Scalar}; + use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero}; + use itertools::izip; use sampling::source::Source; - use crate::{elem::{Ciphertext, Infos, Plaintext}, keys::SecretKey}; + use crate::{ + elem::{Infos, RLWECt, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; - let log_q: usize = 54; + let log_k_ct: usize = 54; + let log_k_pt: usize = 40; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6; + let bound: f64 = sigma * 6.0; - let mut ct: Ciphertext>> = Ciphertext::>>::new(&module, log_base2k, log_q, 2); - let mut pt: Plaintext>> = Plaintext::>>::new(&module, log_base2k, log_q); + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - let mut source_xe = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = 
ScratchOwned::new(ct.encrypt_encsk_scratch_bytes(&module, ct.size())); + let mut scratch: ScratchOwned = + ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); - let mut sk: SecretKey>> = SecretKey::new(&module); - let mut sk_prep - sk.svp_prepare(&module, &mut sk_prep); + let sk: SecretKey> = SecretKey::new(&module); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); ct.encrypt_sk( &module, Some(&pt), - &sk_prep, + &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow(), sigma, bound, ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); + + module.free(); } } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index ee8bb94..767d1eb 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,31 +1,27 @@ use base2k::{ - Backend, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, ZnxInfos, FFT64 + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, }; use sampling::source::Source; use crate::elem::derive_size; pub struct SecretKey { - data: T, + pub data: 
ScalarZnx, } -impl SecretKey { - pub fn data(&self) -> &T { - &self.data - } - - pub fn data_mut(&mut self) -> &mut T { - &mut self.data - } -} - -impl SecretKey>> { +impl SecretKey> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar(1), } } +} +impl SecretKey +where + S: AsMut<[u8]> + AsRef<[u8]>, +{ pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); } @@ -33,27 +29,66 @@ impl SecretKey>> { pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_ternary_hw(0, hw, source); } +} - pub fn svp_prepare(&self, module: &Module, sk_prep: &mut SecretKey>) - where - ScalarZnxDft: ScalarZnxDftToMut, - { - module.svp_prepare(&mut sk_prep.data, 0, &self.data, 0) +impl ScalarZnxToMut for SecretKey +where + ScalarZnx: ScalarZnxToMut, +{ + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + self.data.to_mut() } } -type SecretKeyPrep = SecretKey>; +impl ScalarZnxToRef for SecretKey +where + ScalarZnx: ScalarZnxToRef, +{ + fn to_ref(&self) -> ScalarZnx<&[u8]> { + self.data.to_ref() + } +} -impl SecretKey, B>> { - pub fn new(module: &Module) -> Self{ - Self{ - data: module.new_scalar_znx_dft(1) +pub struct SecretKeyDft { + pub data: ScalarZnxDft, +} + +impl SecretKeyDft, B> { + pub fn new(module: &Module) -> Self { + Self { + data: module.new_scalar_znx_dft(1), } } + + pub fn dft(&mut self, module: &Module, sk: &SecretKey) + where + SecretKeyDft, B>: ScalarZnxDftToMut, + SecretKey: ScalarZnxToRef, + { + module.svp_prepare(self, 0, sk, 0) + } +} + +impl ScalarZnxDftToMut for SecretKeyDft +where + ScalarZnxDft: ScalarZnxDftToMut, +{ + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl ScalarZnxDftToRef for SecretKeyDft +where + ScalarZnxDft: ScalarZnxDftToRef, +{ + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + self.data.to_ref() + } } pub struct PublicKey { - data: VecZnxDft, + pub data: VecZnxDft, } impl PublicKey, B> { From 
64874dbda8ad6ae242cad18bfa64bd990168f852 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 15:51:01 +0200 Subject: [PATCH 42/87] multiples fixes to base2k, including svp to take into account column interleaving --- base2k/examples/rlwe_encrypt.rs | 2 +- base2k/src/ffi/svp.rs | 2 ++ base2k/src/lib.rs | 20 +++++++++++++++++- base2k/src/mat_znx_dft_ops.rs | 2 +- base2k/src/scalar_znx.rs | 22 +++++++++++++------ base2k/src/scalar_znx_dft.rs | 15 +++++++++++++ base2k/src/scalar_znx_dft_ops.rs | 4 ++++ base2k/src/vec_znx_big.rs | 36 ++++++++++++++++++++++++++++++++ base2k/src/vec_znx_big_ops.rs | 2 +- base2k/src/vec_znx_dft_ops.rs | 2 +- base2k/src/vec_znx_ops.rs | 4 ++-- base2k/src/znx_base.rs | 2 +- 12 files changed, 99 insertions(+), 14 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 16b7d3a..b9d78f4 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -20,7 +20,7 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: ScalarZnx> = module.new_scalar(1); + let mut s: ScalarZnx> = module.new_scalar_znx(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs index 9d4999f..08b2da1 100644 --- a/base2k/src/ffi/svp.rs +++ b/base2k/src/ffi/svp.rs @@ -39,8 +39,10 @@ unsafe extern "C" { module: *const MODULE, res: *const VEC_ZNX_DFT, res_size: u64, + res_cols: u64, ppol: *const SVP_PPOL, a: *const VEC_ZNX_DFT, a_size: u64, + a_cols: u64, ); } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f3b2525..450a69f 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -177,7 +177,7 @@ impl Scratch { } } - pub fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + pub fn tmp_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * 
std::mem::size_of::()); unsafe { @@ -188,6 +188,24 @@ impl Scratch { } } + pub fn tmp_scalar(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); + + ( + ScalarZnx::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + + pub fn tmp_scalar_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); + + ( + ScalarZnxDft::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + pub fn tmp_vec_znx_dft( &mut self, module: &Module, diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index ae0cbb5..85e6264 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -279,7 +279,7 @@ impl MatZnxDftOps for Module { ); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes( + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( res.size(), a.size(), b.rows(), diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 731add3..dde286a 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -98,24 +98,34 @@ impl>> ScalarZnx { pub type ScalarZnxOwned = ScalarZnx>; +pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { + ScalarZnxOwned::bytes_of::(module.n(), cols) +} + pub trait ScalarZnxAlloc { - fn bytes_of_scalar(&self, cols: usize) -> usize; - fn new_scalar(&self, cols: usize) -> ScalarZnxOwned; - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; + fn bytes_of_scalar_znx(&self, cols: usize) -> usize; + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned; + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; } impl ScalarZnxAlloc for Module { - fn bytes_of_scalar(&self, cols: usize) 
-> usize { + fn bytes_of_scalar_znx(&self, cols: usize) -> usize { ScalarZnxOwned::bytes_of::(self.n(), cols) } - fn new_scalar(&self, cols: usize) -> ScalarZnxOwned { + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { ScalarZnxOwned::new::(self.n(), cols) } - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) } } +impl ScalarZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { data, n, cols } + } +} + pub trait ScalarZnxToRef { fn to_ref(&self) -> ScalarZnx<&[u8]>; } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index c93609f..3626625 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -52,6 +52,10 @@ impl> ZnxView for ScalarZnxDft { type Scalar = f64; } +pub(crate) fn bytes_of_scalar_znx_dft(module: &Module, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(module, cols) +} + impl>, B: Backend> ScalarZnxDft { pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } @@ -79,6 +83,17 @@ impl>, B: Backend> ScalarZnxDft { } } +impl ScalarZnxDft { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { + data, + n, + cols, + _phantom: PhantomData, + } + } +} + pub type ScalarZnxDftOwned = ScalarZnxDft, B>; pub trait ScalarZnxDftToRef { diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index f5f8f7f..f02fa03 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -71,9 +71,11 @@ impl ScalarZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, + res.cols() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, b.at_ptr(b_col, 0) as *const vec_znx_dft_t, b.size() as u64, + b.cols() as u64, ) } } @@ -90,9 +92,11 @@ 
impl ScalarZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, + res.cols() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, res.at_ptr(res_col, 0) as *const vec_znx_dft_t, res.size() as u64, + res.cols() as u64, ) } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8f70272..f5f220e 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -2,6 +2,7 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::marker::PhantomData; +use std::{cmp::min, fmt}; pub struct VecZnxBig { data: D, @@ -162,3 +163,38 @@ impl VecZnxBigToRef for VecZnxBig<&[u8], B> { } } } + +impl> fmt::Display for VecZnxBig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 169c66a..933deb3 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -528,7 +528,7 @@ impl VecZnxBigOps for Module { // assert_alignement(tmp_bytes.as_ptr()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vec_znx_big_normalize_tmp_bytes( + let (tmp_bytes, _) = scratch.tmp_slice(::vec_znx_big_normalize_tmp_bytes( &self, )); unsafe { diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 83b7c26..927e39e 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -141,7 +141,7 @@ impl VecZnxDftOps for Module { let mut res_mut = res.to_mut(); let a_ref = a.to_ref(); - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes()); let min_size: usize = min(res_mut.size(), a_ref.size()); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index cdabe24..c80e9f1 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -175,7 +175,7 @@ impl VecZnxOps for Module { assert_eq!(res.n(), self.n()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( @@ -203,7 +203,7 @@ impl VecZnxOps for Module { assert_eq!(a.n(), self.n()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 5230dfd..94da450 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -171,7 +171,7 @@ where let k_rem: usize = k % log_base2k; if k_rem != 0 { - 
let (carry, _) = scratch.tmp_scalar_slice::(rsh_tmp_bytes::(n)); + let (carry, _) = scratch.tmp_slice::(rsh_tmp_bytes::(n)); unsafe { std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); From 6cbd2a6a9380dd7648aac6e05e6ca93227757321 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 16:47:58 +0200 Subject: [PATCH 43/87] Some fixes & QoL to Base2k --- base2k/examples/rlwe_encrypt.rs | 6 ++--- base2k/spqlios-arithmetic | 2 +- base2k/src/encoding.rs | 44 ++++++++++++++++++--------------- base2k/src/mat_znx_dft_ops.rs | 2 +- base2k/src/sampling.rs | 8 +++--- base2k/src/stats.rs | 4 +-- base2k/src/vec_znx_big.rs | 4 +-- base2k/src/vec_znx_dft.rs | 36 +++++++++++++++++++++++++++ 8 files changed, 73 insertions(+), 33 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index b9d78f4..4db6ef5 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,7 +1,7 @@ use base2k::{ - AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, - ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, - VecZnxDftOps, VecZnxOps, ZnxInfos, + AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index b6fa494..b919282 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit b6fa494a14c52842712f8ff032ea80812467dec2 +Subproject commit b919282c9b913e8b11418df6afdb0baa02debc9b diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index ba48474..45214c6 
100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -17,6 +17,20 @@ pub trait Encoding { /// * `log_max`: base two logarithm of the infinity norm of the input data. fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); + /// encodes a single i64 on the receiver at the given index. + /// + /// # Arguments + /// + /// * `col_i`: the index of the poly where to encode the data. + /// * `log_base2k`: base two negative logarithm decomposition of the receiver. + /// * `log_k`: base two negative logarithm of the scaling of the data. + /// * `i`: index of the coefficient on which to encode the data. + /// * `data`: data to encode on the receiver. + /// * `log_max`: base two logarithm of the infinity norm of the input data. + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); +} + +pub trait Decoding { /// decode a vector of i64 from the receiver. /// /// # Arguments @@ -35,18 +49,6 @@ pub trait Encoding { /// * `data`: data to decode from the receiver. fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); - /// encodes a single i64 on the receiver at the given index. - /// - /// # Arguments - /// - /// * `col_i`: the index of the poly where to encode the data. - /// * `log_base2k`: base two negative logarithm decomposition of the receiver. - /// * `log_k`: base two negative logarithm of the scaling of the data. - /// * `i`: index of the coefficient on which to encode the data. - /// * `data`: data to encode on the receiver. - /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); - /// decode a single of i64 from the receiver at the given index. 
/// /// # Arguments @@ -64,6 +66,12 @@ impl + AsRef<[u8]>> Encoding for VecZnx { encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) + } +} + +impl> Decoding for VecZnx { fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { decode_vec_i64(self, col_i, log_base2k, log_k, data) } @@ -72,10 +80,6 @@ impl + AsRef<[u8]>> Encoding for VecZnx { decode_vec_float(self, col_i, log_base2k, data) } - fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) - } - fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { decode_coeff_i64(self, col_i, log_base2k, log_k, i) } @@ -139,7 +143,7 @@ fn encode_vec_i64 + AsRef<[u8]>>( } } -fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { +fn decode_vec_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { @@ -167,7 +171,7 @@ fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log }) } -fn decode_vec_float + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { +fn decode_vec_float>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { let size: usize = a.size(); #[cfg(debug_assertions)] { @@ -252,7 +256,7 @@ fn encode_coeff_i64 + AsRef<[u8]>>( } } -fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); @@ -280,7 
+284,7 @@ fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, l mod tests { use crate::vec_znx_ops::*; use crate::znx_base::*; - use crate::{Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; + use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 85e6264..f302e9b 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -305,7 +305,7 @@ impl MatZnxDftOps for Module { #[cfg(test)] mod tests { use crate::{ - Encoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, + Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, }; use sampling::source::Source; diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b2d6f22..b4e1489 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -80,7 +80,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -123,7 +123,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -198,7 +198,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -241,7 +241,7 @@ where ); let limb: usize = 
(log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index c6d16b4..8db40f2 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Encoding, VecZnx}; +use crate::{Decoding, VecZnx}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; @@ -9,7 +9,7 @@ pub trait Stats { fn std(&self, col_i: usize, log_base2k: usize) -> f64; } -impl + AsRef<[u8]>> Stats for VecZnx { +impl> Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index f5f220e..d8c1bdd 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,8 +1,8 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use std::fmt; use std::marker::PhantomData; -use std::{cmp::min, fmt}; pub struct VecZnxBig { data: D, @@ -168,7 +168,7 @@ impl> fmt::Display for VecZnxBig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, - "VecZnx(n={}, cols={}, size={})", + "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size )?; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 66e58cf..0e7f952 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::fmt; pub struct VecZnxDft 
{ data: D, @@ -163,3 +164,38 @@ impl VecZnxDftToRef for VecZnxDft<&[u8], B> { } } } + +impl> fmt::Display for VecZnxDft { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxDft(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} From 48ac28c4ce403f3f0c36ed84545c8b3750844a00 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 17:04:42 +0200 Subject: [PATCH 44/87] Added sk/pk encryption for rlwe/rlwedft with tests --- rlwe/src/elem.rs | 4 +- rlwe/src/encryption.rs | 406 +++++++++++++++++++++++++++++++++++------ rlwe/src/keys.rs | 113 ++++++++++-- 3 files changed, 451 insertions(+), 72 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index fe1b3b4..d1ddb74 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -154,9 +154,9 @@ pub struct RLWECtDft { } impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { Self { - data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)), + data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index 148ded4..0bdae33 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,16 +1,16 @@ use std::cmp::min; use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, 
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, - VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxToMut, VecZnxToRef, + AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, + VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, }; use sampling::source::Source; use crate::{ elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, - keys::SecretKeyDft, + keys::{PublicKey, SecretDistribution, SecretKeyDft}, }; pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { @@ -24,9 +24,9 @@ pub fn encrypt_rlwe_sk( sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx

: VecZnxToRef, @@ -74,12 +74,10 @@ pub fn decrypt_rlwe( VecZnx: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let size: usize = min(pt.size(), ct.size()); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) @@ -111,16 +109,16 @@ impl RLWECt { sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx

: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { encrypt_rlwe_sk( - module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound, + module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch, ) } @@ -132,84 +130,258 @@ impl RLWECt { { decrypt_rlwe(module, pt, self, sk, scratch); } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + encrypt_rlwe_pk( + module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch, + ) + } } -pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +pub(crate) fn encrypt_zero_rlwe_dft_sk( + module: &Module, + ct: &mut RLWECtDft, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.log_base2k(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); + + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), + _ => {} + } + assert_eq!(ct.cols(), 2); + } + + // ct[1] = DFT(a) + { + let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); + tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); + } + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + // c0_dft = ct[1] * DFT(s) + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); + } + + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. 
+ let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); + module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); + module.vec_znx_negate_inplace(&mut tmp_znx, 0); + // ct[0] = DFT(-as + e) + module.vec_znx_dft(ct, 0, &tmp_znx, 0); +} + +pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() +} + +pub fn decrypt_rlwe_dft( + module: &Module, + pt: &mut RLWEPt

, + ct: &RLWECtDft, + sk: &SecretKeyDft, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + { + let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.log_base2k(); + pt.log_k = min(pt.log_k(), ct.log_k()); +} + +pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) } impl RLWECtDft { - fn encrypt_zero( + pub(crate) fn encrypt_zero_sk( + &mut self, module: &Module, - ct: &mut RLWECtDft, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 1, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { 
- let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + e) - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); + encrypt_zero_rlwe_dft_sk( + module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) } - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() + pub fn decrypt( + &self, + module: &Module, + pt: &mut RLWEPt

, + sk_dft: &SecretKeyDft, + scratch: &mut Scratch, + ) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } } +pub fn encrypt_rlwe_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() +} + +pub(crate) fn encrypt_rlwe_pk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.log_base2k(), pk.log_base2k()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.log_base2k(), pk.log_base2k()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.log_base2k(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + #[cfg(test)] mod tests { - use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero}; + use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; use itertools::izip; use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, + elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes}, + keys::{PublicKey, SecretKey, SecretKeyDft}, }; - use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; + use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54; - let log_k_pt: usize = 40; + let log_k_pt: usize = 30; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; @@ -217,13 +389,16 @@ mod tests { let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: 
ScratchOwned = ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); - let sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); sk_dft.dft(&module, &sk); @@ -242,9 +417,9 @@ mod tests { &sk_dft, &mut source_xa, &mut source_xe, - scratch.borrow(), sigma, bound, + scratch.borrow(), ); pt.data.zero(); @@ -256,6 +431,7 @@ mod tests { pt.data .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + // TODO: properly assert the decryption noise through std(dec(ct) - pt) let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { let b_scaled = (*b as f64) / scale; @@ -269,4 +445,118 @@ mod tests { module.free(); } + + #[test] + fn encrypt_zero_rlwe_dft_sk_fft64() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + + let mut scratch: ScratchOwned = ScratchOwned::new( + encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size()) + | decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size()) + | encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + 
scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); + } + + #[test] + fn encrypt_pk_vec_znx_fft64() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) + | decrypt_rlwe_scratch_bytes(&module, ct.size()) + | encrypt_rlwe_pk_scratch_bytes(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) 
* (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); + } } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 767d1eb..89c33e3 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,19 +1,31 @@ use base2k::{ Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, - ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, }; use sampling::source::Source; -use crate::elem::derive_size; +use crate::{ + elem::{Infos, RLWECtDft}, + encryption::encrypt_zero_rlwe_dft_scratch_bytes, +}; + +#[derive(Clone, Copy, Debug)] +pub enum SecretDistribution { + TernaryFixed(usize), // Ternary with fixed Hamming weight + TernaryProb(f64), // Ternary with probabilistic Hamming weight + NONE, +} pub struct SecretKey { pub data: ScalarZnx, + pub dist: SecretDistribution, } impl SecretKey> { pub fn new(module: &Module) -> Self { Self { - data: module.new_scalar(1), + data: module.new_scalar_znx(1), + dist: SecretDistribution::NONE, } } } @@ -24,10 +36,12 @@ where { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); + self.dist = SecretDistribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_ternary_hw(0, hw, source); + self.dist = SecretDistribution::TernaryFixed(hw); } } @@ -51,12 +65,14 @@ where pub struct SecretKeyDft { pub data: ScalarZnxDft, + pub dist: SecretDistribution, } impl SecretKeyDft, B> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar_znx_dft(1), + dist: SecretDistribution::NONE, } } @@ -65,7 +81,16 @@ impl SecretKeyDft, B> { SecretKeyDft, B>: ScalarZnxDftToMut, SecretKey: ScalarZnxToRef, { - module.svp_prepare(self, 0, sk, 0) + #[cfg(debug_assertions)] + { + match sk.dist { + 
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + _ => {} + } + } + + module.svp_prepare(self, 0, sk, 0); + self.dist = sk.dist; } } @@ -88,21 +113,85 @@ where } pub struct PublicKey { - pub data: VecZnxDft, + pub data: RLWECtDft, + pub dist: SecretDistribution, } impl PublicKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_q)), + data: RLWECtDft::new(module, log_base2k, log_k, 2), + dist: SecretDistribution::NONE, } } } -impl> PublicKey { - pub fn generate(&mut self, module: &Module, sk: &SecretKey>, scratch: &mut Scratch) - where - ScalarZnxDft: ScalarZnxDftToMut, - { +impl Infos for PublicKey { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data.data + } + + fn log_base2k(&self) -> usize { + self.data.log_base2k + } + + fn log_k(&self) -> usize { + self.data.log_k + } +} + +impl VecZnxDftToMut for PublicKey +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for PublicKey +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl PublicKey { + pub fn generate( + &mut self, + module: &Module, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + ) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + match sk_dft.dist { + SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"), + _ => {} + } + } + + // Its ok to allocate scratch space here since pk is usually generated only once. 
+ let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size())); + self.data.encrypt_zero_sk( + module, + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch.borrow(), + ); + self.dist = sk_dft.dist; } } From 2ec905bbc36cfa88f78b7d02044e9150e30c5ca4 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 10:16:20 +0200 Subject: [PATCH 45/87] added vec_znx_idft_consume --- base2k/src/vec_znx_big.rs | 2 +- base2k/src/vec_znx_dft.rs | 8 +++++++- base2k/src/vec_znx_dft_ops.rs | 29 ++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index d8c1bdd..2875b97 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, ZnxSliceSize, alloc_aligned}; use std::fmt; use std::marker::PhantomData; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 0e7f952..61e1be5 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; use std::fmt; pub struct VecZnxDft { @@ -13,6 +13,12 @@ pub struct VecZnxDft { _phantom: PhantomData, } +impl VecZnxDft { + pub fn into_big(self) -> VecZnxBig { + VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) + } +} + impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 927e39e..cf06cc2 
100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,7 +1,10 @@ use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize}; +use crate::{ + Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, + ZnxSliceSize, +}; use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; use std::cmp::min; @@ -44,6 +47,9 @@ pub trait VecZnxDftOps { where R: VecZnxBigToMut, A: VecZnxDftToMut; + fn vec_znx_idft_consume(&self, a: VecZnxDft, a_cols: usize) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut; fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where @@ -97,6 +103,27 @@ impl VecZnxDftOps for Module { } } + fn vec_znx_idft_consume(&self, mut a: VecZnxDft, a_col: usize) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + unsafe { + (0..a_mut.size()).for_each(|j| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + + a.into_big() + } + } + fn vec_znx_idft_tmp_bytes(&self) -> usize { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } } From 398ad604d9a1a3237bcaf2b917082b3131c5b766 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 10:24:35 +0200 Subject: [PATCH 46/87] added GRLWE and RGSW --- rlwe/src/elem.rs | 104 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index d1ddb74..e86fd08 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,6 +1,6 @@ use base2k::{ - Backend, Module, VecZnx, 
VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, - ZnxInfos, + Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, }; pub trait Infos { @@ -197,6 +197,106 @@ where } } +pub struct GRLWECt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GRLWECt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GRLWECt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +pub struct RGSWCt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RGSWCt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RGSWCt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + 
fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { (log_k + log_base2k - 1) / log_base2k } From 8b3b2e4b9c03bf9035b55b556b3e4d0adbff4578 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 10:55:51 +0200 Subject: [PATCH 47/87] added grlwe sk encryption --- base2k/src/vec_znx_big.rs | 2 +- rlwe/src/elem.rs | 261 +---------------------- rlwe/src/elem_grlwe.rs | 53 +++++ rlwe/src/elem_rgsw.rs | 140 ++++++++++++ rlwe/src/{encryption.rs => elem_rlwe.rs} | 233 ++++++++++++++++---- rlwe/src/keys.rs | 10 +- rlwe/src/lib.rs | 5 +- rlwe/src/utils.rs | 3 + 8 files changed, 400 insertions(+), 307 deletions(-) create mode 100644 rlwe/src/elem_grlwe.rs create mode 100644 rlwe/src/elem_rgsw.rs rename rlwe/src/{encryption.rs => elem_rlwe.rs} (76%) create mode 100644 rlwe/src/utils.rs diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 2875b97..d8c1bdd 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::fmt; use std::marker::PhantomData; diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index e86fd08..c943de1 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,7 +1,6 @@ -use base2k::{ - Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; +use base2k::ZnxInfos; + +use crate::utils::derive_size; pub trait Infos { type Inner: ZnxInfos; @@ -46,257 +45,3 @@ pub trait Infos { /// Returns the bit precision of the ciphertext. 
fn log_k(&self) -> usize; } - -pub struct RLWECt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWECt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWECt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWECt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWEPt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type 
Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub struct GRLWECt { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GRLWECt, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GRLWECt { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for GRLWECt -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for GRLWECt -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub struct RGSWCt { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RGSWCt, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RGSWCt { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for RGSWCt -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) 
-> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for RGSWCt -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { - (log_k + log_base2k - 1) / log_base2k -} diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs new file mode 100644 index 0000000..b269cb3 --- /dev/null +++ b/rlwe/src/elem_grlwe.rs @@ -0,0 +1,53 @@ +use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module}; + +use crate::{elem::Infos, utils::derive_size}; + +pub struct GRLWECt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GRLWECt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GRLWECt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs new file mode 100644 index 0000000..1a1ea24 --- /dev/null +++ b/rlwe/src/elem_rgsw.rs @@ -0,0 +1,140 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + elem_grlwe::GRLWECt, + 
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, + utils::derive_size, +}; + +pub struct RGSWCt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RGSWCt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RGSWCt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GRLWECt, FFT64> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_bytes(module, size) + + module.bytes_of_vec_znx(2, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(2, size) + } + + pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + RLWECt::encrypt_pk_scratch_bytes(module, pk_size) + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECtDft::decrypt_scratch_bytes(module, size) + } +} + +pub fn encrypt_grlwe_sk( + module: &Module, + ct: &mut GRLWECt, + pt: &ScalarZnx

, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let rows: usize = ct.rows(); + let size: usize = ct.size(); + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); + let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + + let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt { + data: tmp_znx_pt, + log_base2k: ct.log_base2k(), + log_k: ct.log_k(), + }; + + let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt { + data: tmp_znx_ct, + log_base2k: ct.log_base2k(), + log_k: ct.log_k(), + }; + + (0..rows).for_each(|row_i| { + tmp_pt + .data + .at_mut(0, row_i) + .copy_from_slice(&pt.to_ref().raw()); + + tmp_ct.encrypt_sk( + module, + Some(&tmp_pt), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch_3, + ); + + tmp_pt.data.at_mut(0, row_i).fill(0); + + module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0); + module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1); + + module.vmp_prepare_row(ct, row_i, 0, &tmp_dft); + }); +} diff --git a/rlwe/src/encryption.rs b/rlwe/src/elem_rlwe.rs similarity index 76% rename from rlwe/src/encryption.rs rename to rlwe/src/elem_rlwe.rs index 0bdae33..8a7d444 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/elem_rlwe.rs @@ -1,20 +1,180 @@ -use std::cmp::min; - use base2k::{ AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, }; - use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, keys::{PublicKey, SecretDistribution, SecretKeyDft}, + utils::derive_size, }; -pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + 
module.bytes_of_vec_znx_big(1, size) +pub struct RLWECt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWECt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWECt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for RLWEPt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + 
type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for RLWECtDft +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl RLWECt> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } + + pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } } pub fn encrypt_rlwe_sk( @@ -94,11 +254,7 @@ pub fn decrypt_rlwe( module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); pt.log_base2k = ct.log_base2k(); - pt.log_k = min(pt.log_k(), ct.log_k()); -} - -pub fn decrypt_rlwe_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + pt.log_k = pt.log_k().min(ct.log_k()); } impl RLWECt { @@ -207,11 +363,20 @@ pub(crate) fn encrypt_zero_rlwe_dft_sk( module.vec_znx_dft(ct, 0, &tmp_znx, 0); } -pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - 
+ module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() +impl RLWECtDft, FFT64> { + pub fn encrypt_zero_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } } pub fn decrypt_rlwe_dft( @@ -246,14 +411,7 @@ pub fn decrypt_rlwe_dft( module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); pt.log_base2k = ct.log_base2k(); - pt.log_k = min(pt.log_k(), ct.log_k()); -} - -pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) + pt.log_k = pt.log_k().min(ct.log_k()); } impl RLWECtDft { @@ -290,12 +448,6 @@ impl RLWECtDft { } } -pub fn encrypt_rlwe_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() -} - pub(crate) fn encrypt_rlwe_pk( module: &Module, ct: &mut RLWECt, @@ -369,13 +521,10 @@ mod tests { use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, - encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes}, + elem_rlwe::{Infos, RLWECt, RLWECtDft, RLWEPt}, keys::{PublicKey, SecretKey, SecretKeyDft}, }; - use super::{decrypt_rlwe_scratch_bytes, 
encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; - #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); @@ -393,8 +542,9 @@ mod tests { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = - ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECt::decrypt_scratch_bytes(&module, ct.size()), + ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -469,9 +619,8 @@ mod tests { let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); let mut scratch: ScratchOwned = ScratchOwned::new( - encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size()) - | decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size()) - | encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()), + RLWECtDft::decrypt_scratch_bytes(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_bytes(&module, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -523,9 +672,9 @@ mod tests { ); let mut scratch: ScratchOwned = ScratchOwned::new( - encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) - | decrypt_rlwe_scratch_bytes(&module, ct.size()) - | encrypt_rlwe_pk_scratch_bytes(&module, pk.size()), + RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) + | RLWECt::decrypt_scratch_bytes(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_bytes(&module, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 89c33e3..2f7b2c7 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -4,10 +4,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{ - elem::{Infos, RLWECtDft}, - encryption::encrypt_zero_rlwe_dft_scratch_bytes, -}; +use crate::{elem::Infos, elem_rlwe::RLWECtDft}; 
#[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -182,7 +179,10 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size())); + let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_bytes( + module, + self.size(), + )); self.data.encrypt_zero_sk( module, sk_dft, diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 023acb5..9eea116 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -1,3 +1,6 @@ pub mod elem; -pub mod encryption; +pub mod elem_grlwe; +pub mod elem_rgsw; +pub mod elem_rlwe; pub mod keys; +mod utils; diff --git a/rlwe/src/utils.rs b/rlwe/src/utils.rs new file mode 100644 index 0000000..0bb0b45 --- /dev/null +++ b/rlwe/src/utils.rs @@ -0,0 +1,3 @@ +pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { + (log_k + log_base2k - 1) / log_base2k +} From 1f384ce54dd8223ff9f7c84dac80da696da34747 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 15:21:24 +0200 Subject: [PATCH 48/87] Added vec_znx_add/sub_scalar & available on Scratch --- base2k/src/lib.rs | 24 ++++++++---- base2k/src/vec_znx_ops.rs | 77 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 10 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 450a69f..bb8ce55 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -150,19 +150,27 @@ impl Scratch { unsafe { &mut *(data as *mut [u8] as *mut Self) } } - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { - let ptr = data.as_mut_ptr(); - let self_len = data.len(); + #[allow(dead_code)] + fn available(&self) -> usize { + let ptr: *const u8 = self.data.as_ptr(); + let self_len: usize = self.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } - let aligned_offset = 
ptr.align_offset(DEFAULTALIGN); - let aligned_len = self_len.saturating_sub(aligned_offset); + fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); if let Some(rem_len) = aligned_len.checked_sub(take_len) { unsafe { - let rem_ptr = ptr.add(aligned_offset).add(take_len); - let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); - let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); return (take_slice, rem_slice); } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index c80e9f1..f57e99f 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,7 @@ use crate::ffi::vec_znx; use crate::{ - Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, ZnxZero, }; use itertools::izip; use std::cmp::min; @@ -51,12 +52,18 @@ pub trait VecZnxOps { A: VecZnxToRef, B: VecZnxToRef; - /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + /// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`. 
fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; + /// Adds the selected column of `a` on the selected column and limb of `res`. + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef; + /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where @@ -80,6 +87,12 @@ pub trait VecZnxOps { R: VecZnxToMut, A: VecZnxToRef; + /// Subtracts the selected column of `a` on the selected column and limb of `res`. + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef; + // Negates the selected column of `a` and stores the result in `res_col` of `res`. fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -282,6 +295,36 @@ impl VecZnxOps for Module { } } + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxToMut, @@ -315,6 +358,36 @@ impl VecZnxOps for Module { } } + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where 
+ R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, From 107e83c65c062f8a6c5761846fd05baa30c52d85 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 15:35:21 +0200 Subject: [PATCH 49/87] Added grlwe encrypt + test --- rlwe/src/elem_grlwe.rs | 192 ++++++++++++++++++++++++++++++++++++++++- rlwe/src/elem_rlwe.rs | 29 ++++--- 2 files changed, 207 insertions(+), 14 deletions(-) diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index b269cb3..a0000cf 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -1,6 +1,16 @@ -use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module}; +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + ZnxZero, +}; +use sampling::source::Source; -use crate::{elem::Infos, utils::derive_size}; +use crate::{ + elem::Infos, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, + utils::derive_size, +}; pub struct GRLWECt { pub data: MatZnxDft, @@ -18,6 +28,18 @@ impl GRLWECt, B> { } } +impl GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + pub fn get_row(&self, module: &Module, i: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, i, 0); + } +} + impl Infos for 
GRLWECt { type Inner = MatZnxDft; @@ -51,3 +73,169 @@ where self.data.to_ref() } } + +impl GRLWECt, FFT64> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_bytes(module, size) + + module.bytes_of_vec_znx(2, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(2, size) + } + + // pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + // RLWECt::encrypt_pk_scratch_bytes(module, pk_size) + // } +} + +pub fn encrypt_grlwe_sk( + module: &Module, + ct: &mut GRLWECt, + pt: &ScalarZnx

, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let rows: usize = ct.rows(); + let size: usize = ct.size(); + let log_base2k: usize = ct.log_base2k(); + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); + let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + + let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + data: tmp_znx_pt, + log_base2k: log_base2k, + log_k: ct.log_k(), + }; + + let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + data: tmp_znx_ct, + log_base2k: log_base2k, + log_k: ct.log_k(), + }; + + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + Some(&vec_znx_pt), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch_3, + ); + + vec_znx_pt.data.zero(); // zeroes for next iteration + + // Switch vec_znx_ct into DFT domain + module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); + module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct); + }); +} + +impl GRLWECt { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx

, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_grlwe_sk( + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } +} + +#[cfg(test)] +mod tests { + use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::GRLWECt; + + #[test] + fn encrypt_sk_vec_znx_fft64() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as 
f64).exp2(); + assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); + } +} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 8a7d444..19b5496 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -181,7 +181,7 @@ pub fn encrypt_rlwe_sk( module: &Module, ct: &mut RLWECt, pt: Option<&RLWEPt

>, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -206,7 +206,7 @@ pub fn encrypt_rlwe_sk( module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); @@ -227,7 +227,7 @@ pub fn decrypt_rlwe( module: &Module, pt: &mut RLWEPt

, ct: &RLWECt, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, scratch: &mut Scratch, ) where VecZnx

: VecZnxToMut + VecZnxToRef, @@ -241,7 +241,7 @@ pub fn decrypt_rlwe( module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); @@ -262,7 +262,7 @@ impl RLWECt { &mut self, module: &Module, pt: Option<&RLWEPt

>, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -274,17 +274,22 @@ impl RLWECt { ScalarZnxDft: ScalarZnxDftToRef, { encrypt_rlwe_sk( - module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch, + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - pub fn decrypt(&self, module: &Module, pt: &mut RLWEPt

, sk: &SecretKeyDft, scratch: &mut Scratch) - where + pub fn decrypt( + &self, + module: &Module, + pt: &mut RLWEPt

, + sk_dft: &SecretKeyDft, + scratch: &mut Scratch, + ) where VecZnx

: VecZnxToMut + VecZnxToRef, VecZnx: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - decrypt_rlwe(module, pt, self, sk, scratch); + decrypt_rlwe(module, pt, self, sk_dft, scratch); } pub fn encrypt_pk( @@ -526,7 +531,7 @@ mod tests { }; #[test] - fn encrypt_sk_vec_znx_fft64() { + fn encrypt_sk_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -597,7 +602,7 @@ mod tests { } #[test] - fn encrypt_zero_rlwe_dft_sk_fft64() { + fn encrypt_zero_sk_fft64() { let module: Module = Module::::new(1024); let log_base2k: usize = 8; let log_k_ct: usize = 55; @@ -639,7 +644,7 @@ mod tests { } #[test] - fn encrypt_pk_vec_znx_fft64() { + fn encrypt_pk_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54; From de3b34477d4013dd588f1c6deac9cdf7a2de15b3 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 18:32:19 +0200 Subject: [PATCH 50/87] added rgsw encrypt + test --- base2k/src/lib.rs | 4 +- base2k/src/scalar_znx.rs | 57 ++++++++- base2k/src/scalar_znx_dft_ops.rs | 206 +++++++++++++++---------------- base2k/src/vec_znx.rs | 6 +- base2k/src/vec_znx_big_ops.rs | 22 ++++ rlwe/src/elem_grlwe.rs | 8 +- rlwe/src/elem_rgsw.rs | 203 ++++++++++++++++++++++++------ rlwe/src/elem_rlwe.rs | 40 ++++-- 8 files changed, 384 insertions(+), 162 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index bb8ce55..b6ed099 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -196,7 +196,7 @@ impl Scratch { } } - pub fn tmp_scalar(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + pub fn tmp_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); ( @@ -205,7 +205,7 @@ impl Scratch { ) } - pub fn tmp_scalar_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut 
Self) { + pub fn tmp_scalar_znx_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); ( diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index dde286a..28ee38a 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -144,6 +144,17 @@ impl ScalarZnxToMut for ScalarZnx> { } } +impl VecZnxToMut for ScalarZnx>{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -154,6 +165,17 @@ impl ScalarZnxToRef for ScalarZnx> { } } +impl VecZnxToRef for ScalarZnx>{ + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { ScalarZnx { @@ -164,6 +186,17 @@ impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { } } +impl VecZnxToMut for ScalarZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -174,6 +207,17 @@ impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { } } +impl VecZnxToRef for ScalarZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: 
self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx<&[u8]> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -183,3 +227,14 @@ impl ScalarZnxToRef for ScalarZnx<&[u8]> { } } } + +impl VecZnxToRef for ScalarZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index f02fa03..1e0313a 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,103 +1,103 @@ -use crate::ffi::svp; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, - VecZnxDftToMut, VecZnxDftToRef, -}; - -pub trait ScalarZnxDftAlloc { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; -} - -pub trait ScalarZnxDftOps { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef; - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef; - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef; -} - -impl ScalarZnxDftAlloc for Module { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new(self, cols) - } - - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { - ScalarZnxDftOwned::bytes_of(self, cols) - } - - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) 
- } -} - -impl ScalarZnxDftOps for Module { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef, - { - unsafe { - svp::svp_prepare( - self.ptr, - res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, - a.to_ref().at_ptr(a_col, 0), - ) - } - } - - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - b.at_ptr(b_col, 0) as *const vec_znx_dft_t, - b.size() as u64, - b.cols() as u64, - ) - } - } - - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - res.at_ptr(res_col, 0) as *const vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - ) - } - } -} +use crate::ffi::svp; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{ + Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, +}; + +pub trait ScalarZnxDftAlloc { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; + fn new_scalar_znx_dft_from_bytes(&self, cols: 
usize, bytes: Vec) -> ScalarZnxDftOwned; +} + +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarZnxToRef; + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef; + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef; +} + +impl ScalarZnxDftAlloc for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new(self, cols) + } + + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(self, cols) + } + + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) + } +} + +impl ScalarZnxDftOps for Module { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarZnxToRef, + { + unsafe { + svp::svp_prepare( + self.ptr, + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 0), + ) + } + } + + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, + b.size() as u64, + b.cols() as u64, + ) + } + } + + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + { + let mut 
res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + ) + } + } +} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 70d8fb3..31459d4 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -20,9 +20,9 @@ use std::{cmp::min, fmt}; /// are small polynomials of Zn\[X\]. pub struct VecZnx { pub data: D, - n: usize, - cols: usize, - size: usize, + pub n: usize, + pub cols: usize, + pub size: usize, } impl ZnxInfos for VecZnx { diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 933deb3..809a1eb 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -114,6 +114,9 @@ pub trait VecZnxBigOps { R: VecZnxBigToMut, A: VecZnxToRef; + /// Negates `a` inplace. + fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + /// Normalizes `a` and stores the result on `b`. /// /// # Arguments @@ -503,6 +506,25 @@ impl VecZnxBigOps for Module { } } + fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + fn vec_znx_big_normalize( &self, log_base2k: usize, diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index a0000cf..a460ec4 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -91,7 +91,7 @@ pub fn encrypt_grlwe_sk( module: &Module, ct: &mut GRLWECt, pt: &ScalarZnx

, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -131,7 +131,7 @@ pub fn encrypt_grlwe_sk( vec_znx_ct.encrypt_sk( module, Some(&vec_znx_pt), - sk, + sk_dft, source_xa, source_xe, sigma, @@ -186,7 +186,7 @@ mod tests { use super::GRLWECt; #[test] - fn encrypt_sk_vec_znx_fft64() { + fn encrypt_sk_fft64() { let module: Module = Module::::new(2048); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -233,7 +233,7 @@ mod tests { ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); }); module.free(); diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 1a1ea24..75d6583 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -1,13 +1,13 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + ZnxZero, }; use sampling::source::Source; use crate::{ elem::Infos, - elem_grlwe::GRLWECt, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, utils::derive_size, }; @@ -62,28 +62,32 @@ where } } -impl GRLWECt, FFT64> { +impl RGSWCt, FFT64> { pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { RLWECt::encrypt_sk_scratch_bytes(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } +} - pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) 
-> usize { - RLWECt::encrypt_pk_scratch_bytes(module, pk_size) - } - - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECtDft::decrypt_scratch_bytes(module, size) +impl RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); } } -pub fn encrypt_grlwe_sk( +pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut GRLWECt, + ct: &mut RGSWCt, pt: &ScalarZnx

, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -94,47 +98,164 @@ pub fn encrypt_grlwe_sk( ScalarZnx

: ScalarZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let rows: usize = ct.rows(); let size: usize = ct.size(); + let log_base2k: usize = ct.log_base2k(); - let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); - let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { data: tmp_znx_pt, - log_base2k: ct.log_base2k(), + log_base2k: log_base2k, log_k: ct.log_k(), }; - let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { data: tmp_znx_ct, - log_base2k: ct.log_base2k(), + log_base2k: log_base2k, log_k: ct.log_k(), }; - (0..rows).for_each(|row_i| { - tmp_pt - .data - .at_mut(0, row_i) - .copy_from_slice(&pt.to_ref().raw()); + (0..ct.rows()).for_each(|row_j| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - tmp_ct.encrypt_sk( - module, - Some(&tmp_pt), - sk, - source_xa, - source_xe, - sigma, - bound, - scratch_3, - ); + (0..ct.cols()).for_each(|col_i| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + encrypt_rlwe_sk( + module, + &mut vec_znx_ct, + Some((&vec_znx_pt, col_i)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scrach_2, + ); - tmp_pt.data.at_mut(0, row_i).fill(0); + // Switch vec_znx_ct into DFT domain + { + let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size); + module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); + module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); + module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct); + } + }); - module.vec_znx_dft(&mut tmp_dft, 0, 
&tmp_ct, 0); - module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1); - - module.vmp_prepare_row(ct, row_i, 0, &tmp_dft); + vec_znx_pt.data.zero(); // zeroes for next iteration }); } + +impl RGSWCt { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx

, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_rgsw_sk( + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } +} + +#[cfg(test)] +mod tests { + use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, + }; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::RGSWCt; + + #[test] + fn encrypt_rgsw_sk_fft64() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, 
FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); + + module.free(); + } +} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 19b5496..938b3c5 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -180,7 +180,7 @@ impl RLWECt> { pub fn encrypt_rlwe_sk( module: &Module, ct: &mut RLWECt, - pt: Option<&RLWEPt

>, + pt: Option<(&RLWEPt

, usize)>, sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, @@ -213,8 +213,18 @@ pub fn encrypt_rlwe_sk( } // c0_big = m - c0_big - if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); + if let Some((pt, col)) = pt { + match col { + 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), + 1 => { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + module.vec_znx_add_inplace(ct, 1, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); + } + _ => panic!("invalid target column: {}", col), + } + } else { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); } // c0_big += e c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); @@ -273,9 +283,23 @@ impl RLWECt { VecZnx

: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - encrypt_rlwe_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) + if let Some(pt) = pt { + encrypt_rlwe_sk( + module, + self, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ) + } else { + encrypt_rlwe_sk::( + module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } } pub fn decrypt( @@ -483,10 +507,10 @@ pub(crate) fn encrypt_rlwe_pk( let size_pk: usize = pk.size(); // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1); + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); { - let (mut u, _) = scratch_1.tmp_scalar(module, 1); + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); match pk.dist { SecretDistribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" From 9913040aa1b2b3a92041fa72c426816b330a9b32 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 9 May 2025 10:39:00 +0200 Subject: [PATCH 51/87] Added grlwe ops + tests --- base2k/src/scalar_znx.rs | 8 +- base2k/src/vec_znx.rs | 34 ++ base2k/src/vec_znx_big.rs | 35 +- base2k/src/vec_znx_big_ops.rs | 9 +- base2k/src/vec_znx_dft.rs | 37 +- base2k/src/vec_znx_dft_ops.rs | 27 +- base2k/src/znx_base.rs | 8 +- rlwe/src/elem_grlwe.rs | 265 ++++++++--- rlwe/src/elem_rgsw.rs | 94 +--- rlwe/src/elem_rlwe.rs | 282 +++--------- rlwe/src/keys.rs | 11 +- rlwe/src/lib.rs | 1 + rlwe/src/test_fft64/elem_grlwe.rs | 722 ++++++++++++++++++++++++++++++ rlwe/src/test_fft64/elem_rgsw.rs | 88 ++++ rlwe/src/test_fft64/elem_rlwe.rs | 196 ++++++++ rlwe/src/test_fft64/mod.rs | 3 + 16 files changed, 1435 insertions(+), 385 deletions(-) create mode 100644 rlwe/src/test_fft64/elem_grlwe.rs create mode 100644 rlwe/src/test_fft64/elem_rgsw.rs create mode 100644 rlwe/src/test_fft64/elem_rlwe.rs create mode 100644 
rlwe/src/test_fft64/mod.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 28ee38a..108ba3f 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,5 +1,7 @@ use crate::znx_base::ZnxInfos; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut}; +use crate::{ + Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, +}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -144,7 +146,7 @@ impl ScalarZnxToMut for ScalarZnx> { } } -impl VecZnxToMut for ScalarZnx>{ +impl VecZnxToMut for ScalarZnx> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), @@ -165,7 +167,7 @@ impl ScalarZnxToRef for ScalarZnx> { } } -impl VecZnxToRef for ScalarZnx>{ +impl VecZnxToRef for ScalarZnx> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 31459d4..b945b2c 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,6 +1,7 @@ use crate::DataView; use crate::DataViewMut; use crate::ZnxSliceSize; +use crate::ZnxZero; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; @@ -182,6 +183,39 @@ fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, } } +impl VecZnx +where + VecZnx: VecZnxToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
+ pub fn extract_column(&mut self, self_col: usize, a: &VecZnx, a_col: usize) + where + VecZnx: VecZnxToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + impl> fmt::Display for VecZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index d8c1bdd..8b3223b 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use std::fmt; use std::marker::PhantomData; @@ -94,6 +94,39 @@ impl VecZnxBig { } } +impl VecZnxBig +where + VecZnxBig: VecZnxBigToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
+ pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) + where + VecZnxBig: VecZnxBigToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + pub type VecZnxBigOwned = VecZnxBig, B>; pub trait VecZnxBigToRef { diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 809a1eb..8208c97 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -115,7 +115,9 @@ pub trait VecZnxBigOps { A: VecZnxToRef; /// Negates `a` inplace. - fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; /// Normalizes `a` and stores the result on `b`. 
/// @@ -506,7 +508,10 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut { + fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) + where + A: VecZnxBigToMut, + { let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 61e1be5..b4bc973 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,7 +2,9 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned, +}; use std::fmt; pub struct VecZnxDft { @@ -89,6 +91,39 @@ impl>, B: Backend> VecZnxDft { } } +impl VecZnxDft +where + VecZnxDft: VecZnxDftToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
+ pub fn extract_column(&mut self, self_col: usize, a: &VecZnxDft, a_col: usize) + where + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index cf06cc2..282ef4d 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -47,7 +47,9 @@ pub trait VecZnxDftOps { where R: VecZnxBigToMut, A: VecZnxDftToMut; - fn vec_znx_idft_consume(&self, a: VecZnxDft, a_cols: usize) -> VecZnxBig + + /// Consumes a to return IDFT(a) in big coeff space. 
+ fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut; @@ -103,25 +105,28 @@ impl VecZnxDftOps for Module { } } - fn vec_znx_idft_consume(&self, mut a: VecZnxDft, a_col: usize) -> VecZnxBig + fn vec_znx_idft_consume(&self, mut a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut, { let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); unsafe { + // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() (0..a_mut.size()).for_each(|j| { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t, - 1 as u64, - a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - ) + (0..a_mut.cols()).for_each(|i| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); }); - - a.into_big() } + + a.into_big() } fn vec_znx_idft_tmp_bytes(&self) -> usize { diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 94da450..a168e18 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -101,25 +101,25 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { //(Jay)Note: Can't provide blanket impl. 
of ZnxView because Scalar is not known impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} -pub trait ZnxZero: ZnxViewMut +pub trait ZnxZero: ZnxViewMut + ZnxSliceSize where Self: Sized, { fn zero(&mut self) { unsafe { - std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.sl() * self.poly_count()); } } fn zero_at(&mut self, i: usize, j: usize) { unsafe { - std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.sl()); } } } // Blanket implementations -impl ZnxZero for T where T: ZnxViewMut {} +impl ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index a460ec4..b865c1e 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -1,6 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -32,11 +33,23 @@ impl GRLWECt where MatZnxDft: MatZnxDftToRef, { - pub fn get_row(&self, module: &Module, i: usize, res: &mut RLWECtDft) + pub fn get_row(&self, module: &Module, row_i: usize, res: &mut RLWECtDft) where - VecZnxDft: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { - module.vmp_extract_row(res, self, i, 0); + 
module.vmp_extract_row(res, self, row_i, 0); + } +} + +impl GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + pub fn set_row(&mut self, module: &Module, row_i: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, 0, a); } } @@ -75,16 +88,42 @@ where } impl GRLWECt, FFT64> { - pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_bytes(module, size) + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - // pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { - // RLWECt::encrypt_pk_scratch_bytes(module, pk_size) - // } + pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } + + pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size) + } + + pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, a_size) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: 
usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } } pub fn encrypt_grlwe_sk( @@ -170,72 +209,176 @@ impl GRLWECt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } -} -#[cfg(test)] -mod tests { - use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; - use sampling::source::Source; + pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.log_base2k(); - use crate::{ - elem::Infos, - elem_rlwe::{RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, - }; + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } - use super::GRLWECt; + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - #[test] - fn encrypt_sk_fft64() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; + { + let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); + module.vec_znx_dft(&mut a1_dft, 0, a, 1); + module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + } - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - 
let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), - ); + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: 
a.log_base2k(), + log_k: a.log_k(), + }; - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + a.idft(module, &mut a_idft, scratch_1); - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + 
{ + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); }); - module.free(); + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); } } diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 75d6583..eab378a 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -63,8 +63,8 @@ where } impl RGSWCt, FFT64> { - pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_bytes(module, size) + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) @@ -169,93 +169,3 @@ impl RGSWCt { ) } } - -#[cfg(test)] -mod tests { - use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, 
VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, - }; - use sampling::source::Source; - - use crate::{ - elem::Infos, - elem_rlwe::{RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, - }; - - use super::RGSWCt; - - #[test] - fn encrypt_rgsw_sk_fft64() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.cols()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, 
&sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - - pt_want.data.zero(); - }); - }); - - module.free(); - } -} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 938b3c5..54cb4f9 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -1,12 +1,13 @@ use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, - ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, - VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, }; use sampling::source::Source; use crate::{ elem::Infos, + elem_grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, utils::derive_size, }; @@ -18,9 +19,9 @@ pub struct RLWECt { } impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } @@ 
-61,6 +62,27 @@ where } } +impl RLWECt +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + module.vec_znx_dft(res, 0, self, 0); + module.vec_znx_dft(res, 1, self, 1); + } +} + pub struct RLWEPt { pub data: VecZnx, pub log_base2k: usize, @@ -118,9 +140,9 @@ pub struct RLWECtDft { } impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } @@ -161,18 +183,49 @@ where } } +impl RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + } +} + impl RLWECt> { - pub fn encrypt_sk_scratch_bytes(module: 
&Module, size: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } - pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() } - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } } @@ -393,14 +446,14 @@ pub(crate) fn encrypt_zero_rlwe_dft_sk( } impl RLWECtDft, FFT64> { - pub fn encrypt_zero_sk_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + module.bytes_of_vec_znx(1, size) + module.vec_znx_big_normalize_tmp_bytes() } - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) @@ -475,6 +528,14 @@ impl RLWECtDft { { decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } + + pub fn mul_grlwe_assign(&mut self, module: &Module, a: &GRLWECt, scratch: &mut Scratch) + where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + a.mul_rlwe_dft_inplace(module, self, scratch); + } } pub(crate) fn encrypt_rlwe_pk( @@ -517,6 +578,7 @@ pub(crate) fn 
encrypt_rlwe_pk( ), SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} } module.svp_prepare(&mut u_dft, 0, &u, 0); @@ -542,199 +604,3 @@ pub(crate) fn encrypt_rlwe_pk( tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); } - -#[cfg(test)] -mod tests { - use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; - use itertools::izip; - use sampling::source::Source; - - use crate::{ - elem_rlwe::{Infos, RLWECt, RLWECtDft, RLWEPt}, - keys::{PublicKey, SecretKey, SecretKeyDft}, - }; - - #[test] - fn encrypt_sk_fft64() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pt: usize = 30; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECt::decrypt_scratch_bytes(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); - - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); - - ct.encrypt_sk( - &module, - Some(&pt), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - pt.data.zero(); 
- - ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - let mut data_have: Vec = vec![0i64; module.n()]; - - pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); - - module.free(); - } - - #[test] - fn encrypt_zero_sk_fft64() { - let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_bytes(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_bytes(&module, ct_dft.size()), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); - } - - #[test] - fn encrypt_pk_fft64() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pk: usize = 64; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - 
- let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); - pk.generate( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - ); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) - | RLWECt::decrypt_scratch_bytes(&module, ct.size()) - | RLWECt::encrypt_pk_scratch_bytes(&module, pk.size()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - Some(&pt_want), - &pk, - &mut source_xu, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); - } -} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 2f7b2c7..19fda01 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,6 +1,7 @@ use base2k::{ Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, 
VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, + ZnxZero, }; use sampling::source::Source; @@ -10,6 +11,7 @@ use crate::{elem::Infos, elem_rlwe::RLWECtDft}; pub enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight + ZERO, // Debug mod NONE, } @@ -40,6 +42,11 @@ where self.data.fill_ternary_hw(0, hw, source); self.dist = SecretDistribution::TernaryFixed(hw); } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } } impl ScalarZnxToMut for SecretKey @@ -117,7 +124,7 @@ pub struct PublicKey { impl PublicKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: RLWECtDft::new(module, log_base2k, log_k, 2), + data: RLWECtDft::new(module, log_base2k, log_k), dist: SecretDistribution::NONE, } } @@ -179,7 +186,7 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_bytes( + let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_space( module, self.size(), )); diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 9eea116..cad8dbc 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -3,4 +3,5 @@ pub mod elem_grlwe; pub mod elem_rgsw; pub mod elem_rlwe; pub mod keys; +mod test_fft64; mod utils; diff --git a/rlwe/src/test_fft64/elem_grlwe.rs b/rlwe/src/test_fft64/elem_grlwe.rs new file mode 100644 index 0000000..aa871f3 --- /dev/null +++ b/rlwe/src/test_fft64/elem_grlwe.rs @@ -0,0 +1,722 @@ +#[cfg(test)] + +mod test { + use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_grlwe::GRLWECt, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + test_fft64::elem_grlwe::noise_grlwe_rlwe_product, + 
}; + + #[test] + fn encrypt_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); + } + + #[test] + fn mul_rlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: 
GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GRLWECt::mul_rlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, 
log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + 
&sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_dft() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), 
&mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GRLWECt::mul_rlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_grlwe.mul_rlwe_dft( + &module, + &mut ct_rlwe_out_dft, + &ct_rlwe_in_dft, + scratch.borrow(), + ); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_dft_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = 
(log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); 
+ + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GRLWECt::mul_grlwe_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + 
sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s1s2.mul_grlwe( + &module, + &mut ct_grlwe_s0s2, + &ct_grlwe_s0s1, + scratch.borrow(), + ); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } + + #[test] + fn mul_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = 
Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GRLWECt::mul_grlwe_scratch_space( + &module, + ct_grlwe_s0s1.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow()); + + let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: 
f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } +} + +#[allow(dead_code)] +pub(crate) fn noise_grlwe_rlwe_product( + n: f64, + log_base2k: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = a_logq.min(b_logq); + let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (log_base2k)) as f64; + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * a_scale * a_scale * n; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs new file mode 100644 index 0000000..b7af5ca --- /dev/null +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -0,0 +1,88 @@ +#[cfg(test)] +mod tests { + use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, + }; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rgsw::RGSWCt, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + #[test] + fn encrypt_rgsw_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: 
usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let 
std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); + + module.free(); + } +} diff --git a/rlwe/src/test_fft64/elem_rlwe.rs b/rlwe/src/test_fft64/elem_rlwe.rs new file mode 100644 index 0000000..d6f812b --- /dev/null +++ b/rlwe/src/test_fft64/elem_rlwe.rs @@ -0,0 +1,196 @@ +#[cfg(test)] +mod tests { + use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; + use itertools::izip; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::{PublicKey, SecretKey, SecretKeyDft}, + }; + + #[test] + fn encrypt_sk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pt: usize = 30; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + + ct.encrypt_sk( + &module, + Some(&pt), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk_dft, 
scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); + + module.free(); + } + + #[test] + fn encrypt_zero_sk() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); + } + + #[test] + fn encrypt_pk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, 
log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) + | RLWECt::decrypt_scratch_space(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); + } +} diff --git a/rlwe/src/test_fft64/mod.rs b/rlwe/src/test_fft64/mod.rs new file mode 100644 index 0000000..edac310 --- /dev/null +++ b/rlwe/src/test_fft64/mod.rs @@ -0,0 +1,3 @@ +mod elem_grlwe; +mod elem_rgsw; +mod elem_rlwe; From ee7b5744e4119bb3c6e38c396e135b8e47dbf221 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 10 May 2025 11:26:01 +0200 Subject: [PATCH 52/87] Added rgsw ops --- 
rlwe/src/elem_rgsw.rs | 265 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 261 insertions(+), 4 deletions(-) diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index eab378a..75f468f 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -1,12 +1,14 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ elem::Infos, + elem_grlwe::GRLWECt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, utils::derive_size, @@ -69,20 +71,42 @@ impl RGSWCt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } + + pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } + + pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, rgsw_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size) + } } impl RGSWCt where MatZnxDft: MatZnxDftToRef, { - pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) where - VecZnxDft: VecZnxDftToMut, + 
VecZnxDft: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } } +impl RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + pub fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + pub fn encrypt_rgsw_sk( module: &Module, ct: &mut RGSWCt, @@ -168,4 +192,237 @@ impl RGSWCt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } + + pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); + module.vec_znx_dft(&mut a_dft, 0, a, 0); + module.vec_znx_dft(&mut a_dft, 1, a, 1); + module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. 
+ self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn 
mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + } + + pub fn mul_rgsw(&self, module: &Module, res: &mut RGSWCt, a: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize 
= res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, 0, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 0, &tmp_row); + }); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, 1, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 1, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, 0, &tmp_row); + res.set_row(module, row_i, 1, &tmp_row); + }) + } + + pub fn mul_rgsw_inplace(&self, module: &Module, res: &mut RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, 0, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 0, &tmp_row); + }); + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, 1, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 1, &tmp_row); + }); + } } From 17e1678fb031d8e59760da33cbeeea2331ff5cb0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 10 May 2025 11:27:54 +0200 Subject: [PATCH 53/87] Added scratch space size for rgsw ops --- rlwe/src/elem_rgsw.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 75f468f..539aca0 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -81,6 +81,33 @@ impl RGSWCt, FFT64> { pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, rgsw_size: usize) -> usize { 
Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size) } + + pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, a_size) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + pub fn mul_rgsw_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + pub fn mul_rgsw_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } } impl RGSWCt From 912876807e2d15962114ced49495abc6b5213344 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 10 May 2025 11:39:16 +0200 Subject: [PATCH 54/87] wip: rgsw ops test --- rlwe/src/test_fft64/elem_rgsw.rs | 145 ++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 4 deletions(-) diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs index b7af5ca..97f34e5 
100644 --- a/rlwe/src/test_fft64/elem_rgsw.rs +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -1,16 +1,15 @@ #[cfg(test)] mod tests { use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, + FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut, ZnxZero, FFT64 }; use sampling::source::Source; use crate::{ elem::Infos, elem_rgsw::RGSWCt, - elem_rlwe::{RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, test_fft64::elem_rgsw::noise_rgsw_rlwe_product, }; #[test] @@ -85,4 +84,142 @@ mod tests { module.free(); } + + #[test] + fn mul_rlwe() { + let module: Module = Module::::new(32); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + //pt_want + // .data + // .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + 
pt_want.to_mut().at_mut(0, 0)[0] = 1; + + pt_rgsw.raw_mut()[1] = 1; // X^{1} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RGSWCt::mul_rlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + + println!("{}", pt_want.data); + println!("{}", pt_have.data); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } } + +#[allow(dead_code)] +pub(crate) fn noise_rgsw_rlwe_product( + n: f64, + log_base2k: usize, + var_xs: f64, + var_msg: f64, + var_a0_err: f64, + var_a1_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = a_logq.min(b_logq); + let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 
2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (log_base2k)) as f64; + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a0_err * a_scale * a_scale * n; + noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} \ No newline at end of file From 4e5a8dba0952068ef1605a4ca3f413e7da78b747 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 10 May 2025 15:37:13 +0200 Subject: [PATCH 55/87] fixed rgsw mul rlwe test --- rlwe/src/elem_rlwe.rs | 2 +- rlwe/src/test_fft64/elem_rgsw.rs | 42 ++++++++++++++++++++------------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 54cb4f9..72f48a5 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -194,7 +194,7 @@ where pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where - VecZnx: VecZnxToMut + ZnxInfos, + VecZnx: VecZnxToMut, { #[cfg(debug_assertions)] { diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs index 97f34e5..9ab790f 100644 --- a/rlwe/src/test_fft64/elem_rgsw.rs +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -1,7 +1,8 @@ #[cfg(test)] mod tests { use base2k::{ - FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut, ZnxZero, FFT64 + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, 
ZnxViewMut, ZnxZero, }; use sampling::source::Source; @@ -9,7 +10,8 @@ mod tests { elem::Infos, elem_rgsw::RGSWCt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, test_fft64::elem_rgsw::noise_rgsw_rlwe_product, + keys::{SecretKey, SecretKeyDft}, + test_fft64::elem_rgsw::noise_rgsw_rlwe_product, }; #[test] @@ -87,7 +89,7 @@ mod tests { #[test] fn mul_rlwe() { - let module: Module = Module::::new(32); + let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; @@ -109,13 +111,15 @@ mod tests { let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - //pt_want + // pt_want // .data // .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - pt_want.to_mut().at_mut(0, 0)[0] = 1; + pt_want.to_mut().at_mut(0, 0)[1] = 1; - pt_rgsw.raw_mut()[1] = 1; // X^{1} + let r: usize = 1; + + pt_rgsw.raw_mut()[r] = 1; // X^{r} let mut scratch: ScratchOwned = ScratchOwned::new( RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) @@ -161,22 +165,28 @@ mod tests { ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_rotate_inplace(r as i64, &mut pt_want, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - println!("{}", pt_want.data); - println!("{}", pt_have.data); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + let noise_want: f64 = noise_rgsw_rlwe_product( module.n() as f64, log_base2k, 0.5, - 0.5, - 0f64, - 0f64, - sigma * sigma, - 0f64, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, log_k_rlwe_in, log_k_grlwe, ); @@ -222,4 +232,4 @@ pub(crate) fn noise_rgsw_rlwe_product( noise = noise.sqrt(); noise /= b_scale; noise.log2().min(-1.0) // 
max noise is [-2^{-1}, 2^{-1}] -} \ No newline at end of file +} From 5d56d78d91b66655aeaaccab8a8fea5a281ed1fa Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 10 May 2025 18:14:14 +0200 Subject: [PATCH 56/87] factorized out vmp ops into a common trait & implementation --- rlwe/src/elem.rs | 170 ++++++++++++++++- rlwe/src/elem_grlwe.rs | 239 ++++++++++------------- rlwe/src/elem_rgsw.rs | 317 +++++++++---------------------- rlwe/src/test_fft64/elem_rgsw.rs | 8 +- 4 files changed, 362 insertions(+), 372 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index c943de1..98e2677 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,6 +1,13 @@ -use base2k::ZnxInfos; +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, +}; -use crate::utils::derive_size; +use crate::{ + elem_grlwe::GRLWECt, + elem_rlwe::{RLWECt, RLWECtDft}, + utils::derive_size, +}; pub trait Infos { type Inner: ZnxInfos; @@ -45,3 +52,162 @@ pub trait Infos { /// Returns the bit precision of the ciphertext. 
fn log_k(&self) -> usize; } + +pub trait GetRow { + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut; +} + +pub trait SetRow { + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef; +} + +pub(crate) trait MatZnxDftProducts: Infos +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 
0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn mul_grlwe(&self, module: &Module, res: &mut GRLWECt, a: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], 
FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..self.cols()).for_each(|col_j| { + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, col_j, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); + }) + } +} diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index b865c1e..0567c07 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::SecretKeyDft, utils::derive_size, @@ -211,8 +211,106 @@ impl GRLWECt { } pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut 
Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); + } +} + +impl GetRow for GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl MatZnxDftProducts, C> for GRLWECt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { @@ -242,143 +340,4 @@ impl GRLWECt { module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. 
- self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn 
mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - } } diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 539aca0..beeeef9 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, elem_grlwe::GRLWECt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, @@ -110,30 +110,6 @@ impl RGSWCt, FFT64> { } } -impl RGSWCt -where - MatZnxDft: MatZnxDftToRef, -{ - pub fn get_row(&self, 
module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, col_j); - } -} - -impl RGSWCt -where - MatZnxDft: MatZnxDftToMut, -{ - pub fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, col_j, a); - } -} - pub fn encrypt_rgsw_sk( module: &Module, ct: &mut RGSWCt, @@ -221,6 +197,96 @@ impl RGSWCt { } pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: 
&mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); + } +} + +impl GetRow for RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl MatZnxDftProducts, C> for RGSWCt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -251,205 +317,4 @@ impl RGSWCt { module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. 
- self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn 
mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - } - - pub fn mul_rgsw(&self, module: &Module, res: &mut RGSWCt, a: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize 
= res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, 0, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 0, &tmp_row); - }); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, 1, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 1, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, 0, &tmp_row); - res.set_row(module, row_i, 1, &tmp_row); - }) - } - - pub fn mul_rgsw_inplace(&self, module: &Module, res: &mut RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, 0, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 0, &tmp_row); - }); - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, 1, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 1, &tmp_row); - }); - } } diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs index 9ab790f..e076237 100644 --- a/rlwe/src/test_fft64/elem_rgsw.rs +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -7,7 +7,7 @@ mod tests { use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos}, elem_rgsw::RGSWCt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::{SecretKey, SecretKeyDft}, @@ -117,9 +117,9 @@ mod tests { pt_want.to_mut().at_mut(0, 0)[1] = 1; - let r: usize = 1; + let k: usize = 1; - pt_rgsw.raw_mut()[r] = 1; // X^{r} + pt_rgsw.raw_mut()[k] = 1; // X^{k} 
let mut scratch: ScratchOwned = ScratchOwned::new( RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) @@ -165,7 +165,7 @@ mod tests { ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_rotate_inplace(r as i64, &mut pt_want, 0); + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); From 54fab8e4f37dedb0f3d3482559f7690482f95ecd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sun, 11 May 2025 11:13:53 +0200 Subject: [PATCH 57/87] renamed crate & files --- Cargo.toml | 2 +- {rlwe => core}/Cargo.toml | 0 {rlwe => core}/src/elem.rs | 4 ++-- rlwe/src/elem_grlwe.rs => core/src/grlwe.rs | 2 +- {rlwe => core}/src/keys.rs | 2 +- core/src/lib.rs | 7 +++++++ rlwe/src/elem_rgsw.rs => core/src/rgsw.rs | 4 ++-- rlwe/src/elem_rlwe.rs => core/src/rlwe.rs | 2 +- .../elem_grlwe.rs => core/src/test_fft64/grlwe.rs | 8 ++++---- core/src/test_fft64/mod.rs | 3 +++ .../elem_rgsw.rs => core/src/test_fft64/rgsw.rs | 6 +++--- .../elem_rlwe.rs => core/src/test_fft64/rlwe.rs | 2 +- {rlwe => core}/src/utils.rs | 0 rlwe/src/lib.rs | 7 ------- rlwe/src/test_fft64/mod.rs | 3 --- 15 files changed, 26 insertions(+), 26 deletions(-) rename {rlwe => core}/Cargo.toml (100%) rename {rlwe => core}/src/elem.rs (99%) rename rlwe/src/elem_grlwe.rs => core/src/grlwe.rs (99%) rename {rlwe => core}/src/keys.rs (99%) create mode 100644 core/src/lib.rs rename rlwe/src/elem_rgsw.rs => core/src/rgsw.rs (99%) rename rlwe/src/elem_rlwe.rs => core/src/rlwe.rs (99%) rename rlwe/src/test_fft64/elem_grlwe.rs => core/src/test_fft64/grlwe.rs (99%) create mode 100644 core/src/test_fft64/mod.rs rename rlwe/src/test_fft64/elem_rgsw.rs => core/src/test_fft64/rgsw.rs (98%) rename rlwe/src/test_fft64/elem_rlwe.rs => core/src/test_fft64/rlwe.rs (99%) rename {rlwe => core}/src/utils.rs (100%) delete mode 100644 rlwe/src/lib.rs delete mode 100644 rlwe/src/test_fft64/mod.rs diff --git a/Cargo.toml b/Cargo.toml 
index b99028c..6f2a91e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["base2k", "rlwe", "sampling", "utils"] +members = ["base2k", "core", "sampling", "utils"] resolver = "3" [workspace.dependencies] diff --git a/rlwe/Cargo.toml b/core/Cargo.toml similarity index 100% rename from rlwe/Cargo.toml rename to core/Cargo.toml diff --git a/rlwe/src/elem.rs b/core/src/elem.rs similarity index 99% rename from rlwe/src/elem.rs rename to core/src/elem.rs index 98e2677..ac245ad 100644 --- a/rlwe/src/elem.rs +++ b/core/src/elem.rs @@ -4,8 +4,8 @@ use base2k::{ }; use crate::{ - elem_grlwe::GRLWECt, - elem_rlwe::{RLWECt, RLWECtDft}, + grlwe::GRLWECt, + rlwe::{RLWECt, RLWECtDft}, utils::derive_size, }; diff --git a/rlwe/src/elem_grlwe.rs b/core/src/grlwe.rs similarity index 99% rename from rlwe/src/elem_grlwe.rs rename to core/src/grlwe.rs index 0567c07..df44a70 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/core/src/grlwe.rs @@ -8,8 +8,8 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::SecretKeyDft, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, utils::derive_size, }; diff --git a/rlwe/src/keys.rs b/core/src/keys.rs similarity index 99% rename from rlwe/src/keys.rs rename to core/src/keys.rs index 19fda01..8285f85 100644 --- a/rlwe/src/keys.rs +++ b/core/src/keys.rs @@ -5,7 +5,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{elem::Infos, elem_rlwe::RLWECtDft}; +use crate::{elem::Infos, rlwe::RLWECtDft}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { diff --git a/core/src/lib.rs b/core/src/lib.rs new file mode 100644 index 0000000..a93d44e --- /dev/null +++ b/core/src/lib.rs @@ -0,0 +1,7 @@ +pub mod elem; +pub mod grlwe; +pub mod keys; +pub mod rgsw; +pub mod rlwe; +mod test_fft64; +mod utils; diff --git a/rlwe/src/elem_rgsw.rs b/core/src/rgsw.rs similarity index 99% rename from rlwe/src/elem_rgsw.rs rename to core/src/rgsw.rs index 
beeeef9..f271c15 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/core/src/rgsw.rs @@ -8,9 +8,9 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, - elem_grlwe::GRLWECt, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, + grlwe::GRLWECt, keys::SecretKeyDft, + rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, utils::derive_size, }; diff --git a/rlwe/src/elem_rlwe.rs b/core/src/rlwe.rs similarity index 99% rename from rlwe/src/elem_rlwe.rs rename to core/src/rlwe.rs index 72f48a5..b52d56d 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/core/src/rlwe.rs @@ -7,7 +7,7 @@ use sampling::source::Source; use crate::{ elem::Infos, - elem_grlwe::GRLWECt, + grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, utils::derive_size, }; diff --git a/rlwe/src/test_fft64/elem_grlwe.rs b/core/src/test_fft64/grlwe.rs similarity index 99% rename from rlwe/src/test_fft64/elem_grlwe.rs rename to core/src/test_fft64/grlwe.rs index aa871f3..86c13ec 100644 --- a/rlwe/src/test_fft64/elem_grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,15 +1,15 @@ #[cfg(test)] -mod test { +mod tests { use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ elem::Infos, - elem_grlwe::GRLWECt, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + grlwe::GRLWECt, keys::{SecretKey, SecretKeyDft}, - test_fft64::elem_grlwe::noise_grlwe_rlwe_product, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::grlwe::noise_grlwe_rlwe_product, }; #[test] diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs new file mode 100644 index 0000000..36d380c --- /dev/null +++ b/core/src/test_fft64/mod.rs @@ -0,0 +1,3 @@ +mod grlwe; +mod rgsw; +mod rlwe; diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/core/src/test_fft64/rgsw.rs similarity index 98% rename from rlwe/src/test_fft64/elem_rgsw.rs rename to core/src/test_fft64/rgsw.rs index e076237..651f6b1 100644 --- 
a/rlwe/src/test_fft64/elem_rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -8,10 +8,10 @@ mod tests { use crate::{ elem::{GetRow, Infos}, - elem_rgsw::RGSWCt, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::{SecretKey, SecretKeyDft}, - test_fft64::elem_rgsw::noise_rgsw_rlwe_product, + rgsw::RGSWCt, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::rgsw::noise_rgsw_rlwe_product, }; #[test] diff --git a/rlwe/src/test_fft64/elem_rlwe.rs b/core/src/test_fft64/rlwe.rs similarity index 99% rename from rlwe/src/test_fft64/elem_rlwe.rs rename to core/src/test_fft64/rlwe.rs index d6f812b..e735aa6 100644 --- a/rlwe/src/test_fft64/elem_rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -6,8 +6,8 @@ mod tests { use crate::{ elem::Infos, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::{PublicKey, SecretKey, SecretKeyDft}, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, }; #[test] diff --git a/rlwe/src/utils.rs b/core/src/utils.rs similarity index 100% rename from rlwe/src/utils.rs rename to core/src/utils.rs diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs deleted file mode 100644 index cad8dbc..0000000 --- a/rlwe/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod elem; -pub mod elem_grlwe; -pub mod elem_rgsw; -pub mod elem_rlwe; -pub mod keys; -mod test_fft64; -mod utils; diff --git a/rlwe/src/test_fft64/mod.rs b/rlwe/src/test_fft64/mod.rs deleted file mode 100644 index edac310..0000000 --- a/rlwe/src/test_fft64/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod elem_grlwe; -mod elem_rgsw; -mod elem_rlwe; From 73098af73a550077b2a49633d7a70cc998f879b6 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sun, 11 May 2025 18:33:47 +0200 Subject: [PATCH 58/87] abstracted products for all cross types --- core/src/elem.rs | 93 ++++-- core/src/grlwe.rs | 190 ++++++------ core/src/rgsw.rs | 183 ++++++------ core/src/rlwe.rs | 501 ++++++++++++++++++++------------ core/src/test_fft64/grlwe.rs | 434 +-------------------------- core/src/test_fft64/mod.rs | 1 + core/src/test_fft64/rgsw.rs | 116 +------- 
core/src/test_fft64/rlwe.rs | 431 ++++++++++++++++++++++++++- core/src/test_fft64/rlwe_dft.rs | 216 ++++++++++++++ 9 files changed, 1219 insertions(+), 946 deletions(-) create mode 100644 core/src/test_fft64/rlwe_dft.rs diff --git a/core/src/elem.rs b/core/src/elem.rs index ac245ad..192bc74 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,10 +1,11 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use crate::{ grlwe::GRLWECt, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft}, utils::derive_size, }; @@ -65,6 +66,36 @@ pub trait SetRow { VecZnxDft: VecZnxDftToRef; } +pub trait ProdByScratchSpace { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; +} + +pub trait ProdBy { + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; +} + +pub trait FromProdByScratchSpace { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait FromProdBy { + fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &L, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; + + fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &L, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; +} + pub(crate) trait MatZnxDftProducts: 
Infos where MatZnxDft: MatZnxDftToRef + ZnxInfos, @@ -75,6 +106,31 @@ where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef; + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; + + fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size) + } + + fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + (Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, a_size) + + module.bytes_of_vec_znx(2, res_size) + } + + fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + (Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + fn mul_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + fn mul_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) + } + fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef + ZnxInfos, @@ -132,7 +188,6 @@ where fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, { let log_base2k: usize = self.log_base2k(); @@ -160,11 +215,10 @@ where module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_grlwe(&self, module: &Module, res: &mut GRLWECt, a: &GRLWECt, scratch: &mut Scratch) + fn mul_mat_rlwe(&self, module: &Module, res: &mut 
R, a: &A, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, + A: GetRow + Infos, + R: SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); @@ -176,22 +230,25 @@ where let min_rows: usize = res.rows().min(a.rows()); - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); + (0..res.rows()).for_each(|row_i| { + (0..self.cols()).for_each(|col_j| { + a.get_row(module, row_i, col_j, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); }); tmp_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) + (0..self.cols()).for_each(|col_j| { + res.set_row(module, row_i, col_j, &tmp_row); + }); + }); } - fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + fn mul_mat_rlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, R: GetRow + SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); @@ -202,12 +259,12 @@ where log_k: res.log_k(), }; - (0..self.cols()).for_each(|col_j| { - (0..res.rows()).for_each(|row_i| { + (0..res.rows()).for_each(|row_i| { + (0..self.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); }); - }) + }); } } diff --git a/core/src/grlwe.rs b/core/src/grlwe.rs index df44a70..9c8c5b8 100644 --- a/core/src/grlwe.rs +++ b/core/src/grlwe.rs @@ -7,8 +7,9 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, + elem::{FromProdBy, FromProdByScratchSpace, 
GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, keys::SecretKeyDft, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, utils::derive_size, }; @@ -41,18 +42,6 @@ where } } -impl GRLWECt -where - MatZnxDft: MatZnxDftToMut, -{ - pub fn set_row(&mut self, module: &Module, row_i: usize, a: &RLWECtDft) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, 0, a); - } -} - impl Infos for GRLWECt { type Inner = MatZnxDft; @@ -94,36 +83,6 @@ impl GRLWECt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - - pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) - } - - pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size) - } - - pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: 
usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } } pub fn encrypt_grlwe_sk( @@ -209,67 +168,6 @@ impl GRLWECt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - - pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); - } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - R: GetRow + SetRow + Infos, - { - MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); - } } impl GetRow for GRLWECt @@ -308,6 
+206,13 @@ impl MatZnxDftProducts, C> for GRLWECt where MatZnxDft: MatZnxDftToRef + ZnxInfos, { + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, @@ -341,3 +246,80 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } + +impl ProdByScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } +} + +impl FromProdByScratchSpace for GRLWECt, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for GRLWECt +where + GRLWECt: GetRow + SetRow + Infos, +{ + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + 
rhs.mul_mat_rlwe_inplace(module, self, scratch); + } +} + +impl FromProdBy, GRLWECt> for GRLWECt +where + GRLWECt: GetRow + SetRow + Infos, + GRLWECt: GetRow + Infos, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &GRLWECt, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &GRLWECt, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } +} diff --git a/core/src/rgsw.rs b/core/src/rgsw.rs index f271c15..c4c7c1c 100644 --- a/core/src/rgsw.rs +++ b/core/src/rgsw.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, + elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, grlwe::GRLWECt, keys::SecretKeyDft, rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, @@ -71,43 +71,6 @@ impl RGSWCt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - - pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) - | module.vec_znx_big_normalize_tmp_bytes()) - } - - pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, rgsw_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size) - } - - pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn 
mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_rgsw_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_rgsw_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } } pub fn encrypt_rgsw_sk( @@ -195,67 +158,6 @@ impl RGSWCt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - - pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); - } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - 
VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - R: GetRow + SetRow + Infos, - { - MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); - } } impl GetRow for RGSWCt @@ -286,6 +188,12 @@ impl MatZnxDftProducts, C> for RGSWCt where MatZnxDft: MatZnxDftToRef + ZnxInfos, { + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, @@ -318,3 +226,80 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } + +impl ProdByScratchSpace for RGSWCt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, 
rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } +} + +impl FromProdByScratchSpace for RGSWCt, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RGSWCt +where + RGSWCt: GetRow + SetRow + Infos, +{ + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } +} + +impl FromProdBy, RGSWCt> for RGSWCt +where + RGSWCt: GetRow + SetRow + Infos, + RGSWCt: GetRow + Infos, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &RGSWCt, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &RGSWCt, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } +} diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs index b52d56d..ef1be64 100644 --- a/core/src/rlwe.rs +++ b/core/src/rlwe.rs @@ -6,9 +6,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, + 
rgsw::RGSWCt, utils::derive_size, }; @@ -83,134 +84,70 @@ where } } -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data +impl ProdByScratchSpace for RLWECt> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) } - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) } } -impl VecZnxToMut for RLWEPt +impl FromProdByScratchSpace for RLWECt> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RLWECt where - VecZnx: VecZnxToMut, + VecZnx: VecZnxToMut + VecZnxToRef, { - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { 
- Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) where - VecZnx: VecZnxToMut, + MatZnxDft: MatZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } + rhs.mul_rlwe_inplace(module, self, scratch); + } - let min_size: usize = self.size().min(res.size()); + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_inplace(module, self, scratch); + } +} - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); +impl FromProdBy, RLWECt> for RLWECt +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, +{ + fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &RLWECt, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe(module, self, lhs, scratch); + } - module.vec_znx_idft(&mut res_big, 
0, &self.data, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &RLWECt, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe(module, self, lhs, scratch); } } @@ -390,6 +327,204 @@ impl RLWECt { } } +pub(crate) fn encrypt_rlwe_pk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.log_base2k(), pk.log_base2k()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.log_base2k(), pk.log_base2k()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.log_base2k(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for RLWEPt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize 
{ + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for RLWECtDft +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + } +} + pub(crate) fn encrypt_zero_rlwe_dft_sk( module: &Module, ct: &mut RLWECtDft, @@ -528,79 +663,81 @@ impl RLWECtDft { { decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } +} - pub fn mul_grlwe_assign(&mut self, module: &Module, a: &GRLWECt, scratch: &mut Scratch) - where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - { - a.mul_rlwe_dft_inplace(module, self, scratch); +impl ProdByScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( + module, lhs, rhs, 
+ ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( + module, lhs, rhs, + ) } } -pub(crate) fn encrypt_rlwe_pk( - module: &Module, - ct: &mut RLWECt, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, +impl FromProdByScratchSpace for RLWECtDft, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RLWECtDft +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, { - #[cfg(debug_assertions)] + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, { - assert_eq!(ct.log_base2k(), pk.log_base2k()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.log_base2k(), pk.log_base2k()); - assert_eq!(pt.n(), module.n()); - } + rhs.mul_rlwe_dft_inplace(module, self, scratch); } - let log_base2k: usize = pk.log_base2k(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. 
- let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); + rhs.mul_rlwe_dft_inplace(module, self, scratch); + } +} + +impl FromProdBy, RLWECtDft> for RLWECtDft +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &RLWECtDft, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_dft(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &RLWECtDft, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_dft(module, self, lhs, scratch); } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); } diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 86c13ec..294411b 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,14 +1,14 @@ #[cfg(test)] mod tests { - use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{SecretKey, SecretKeyDft}, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + rlwe::{RLWECtDft, RLWEPt}, test_fft64::grlwe::noise_grlwe_rlwe_product, }; @@ -67,413 +67,7 @@ mod tests { } #[test] - fn mul_rlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut 
ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | GRLWECt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - 
log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - 
&mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_dft() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | 
RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | GRLWECt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_grlwe.mul_rlwe_dft( - &module, - &mut ct_rlwe_out_dft, - &ct_rlwe_in_dft, - scratch.borrow(), - ); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_dft_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = 
GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - 
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_grlwe() { + fn from_prod_by_grlwe() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -493,7 +87,7 @@ mod tests { let mut scratch: ScratchOwned = ScratchOwned::new( GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::mul_grlwe_scratch_space( + | GRLWECt::from_prod_by_grlwe_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -544,12 +138,7 @@ mod tests { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s1s2.mul_grlwe( - &module, - &mut ct_grlwe_s0s2, - &ct_grlwe_s0s1, - scratch.borrow(), - ); + ct_grlwe_s0s2.from_prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); @@ -584,7 +173,7 @@ mod tests { } #[test] - fn mul_grlwe_inplace() { + fn prod_by_grlwe() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -603,12 +192,7 @@ mod tests { let mut scratch: ScratchOwned = ScratchOwned::new( GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::mul_grlwe_scratch_space( - &module, - ct_grlwe_s0s1.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), + | GRLWECt::prod_by_grlwe_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); @@ 
-654,7 +238,7 @@ mod tests { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow()); + ct_grlwe_s0s1.prod_by_grlwe(&module, &ct_grlwe_s1s2, scratch.borrow()); let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 36d380c..59e2895 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,3 +1,4 @@ mod grlwe; mod rgsw; mod rlwe; +mod rlwe_dft; diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 651f6b1..83df85b 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -2,7 +2,7 @@ mod tests { use base2k::{ FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, }; use sampling::source::Source; @@ -86,120 +86,6 @@ mod tests { module.free(); } - - #[test] - fn mul_rlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let 
mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - // pt_want - // .data - // .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RGSWCt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - 
module.free(); - } } #[allow(dead_code)] diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index e735aa6..acc10a1 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -1,13 +1,19 @@ #[cfg(test)] -mod tests { - use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; +mod tests_rlwe { + use base2k::{ + Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, + ZnxViewMut, ZnxZero, + }; use itertools::izip; use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, + grlwe::GRLWECt, keys::{PublicKey, SecretKey, SecretKeyDft}, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, }; #[test] @@ -193,4 +199,423 @@ mod tests { module.free(); } + + #[test] + fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: 
ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::from_prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.from_prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, 
log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_grlwe_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - 
noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::from_prod_by_rgsw_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + 
ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.from_prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, 
pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } } diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs new file mode 100644 index 0000000..fe0038d --- /dev/null +++ b/core/src/test_fft64/rlwe_dft.rs @@ -0,0 +1,216 
@@ +#[cfg(test)] +mod tests { + use crate::{ + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::grlwe::noise_grlwe_rlwe_product, + }; + use base2k::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + #[test] + fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECtDft::from_prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: 
SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_rlwe_out_dft.from_prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, 
log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECtDft::prod_by_grlwe_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + 
noise_have, + noise_want + ); + + module.free(); + } +} From e38ca404f9af08e2ae74d69a8c680c5bcd3a746c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 12 May 2025 09:27:04 +0200 Subject: [PATCH 59/87] Added tests for GRLWE x RGSW --- core/src/elem.rs | 4 +- core/src/test_fft64/grlwe.rs | 227 +++++++++++++++++++++++++++++- core/src/test_fft64/rlwe_dft.rs | 236 +++++++++++++++++++++++++++++++- 3 files changed, 461 insertions(+), 6 deletions(-) diff --git a/core/src/elem.rs b/core/src/elem.rs index 192bc74..94311bc 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -231,7 +231,7 @@ where let min_rows: usize = res.rows().min(a.rows()); (0..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { a.get_row(module, row_i, col_j, &mut tmp_row); self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); @@ -260,7 +260,7 @@ where }; (0..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 294411b..44fefd6 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,15 +1,16 @@ #[cfg(test)] mod tests { - use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; use crate::{ elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, rlwe::{RLWECtDft, RLWEPt}, - test_fft64::grlwe::noise_grlwe_rlwe_product, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, }; #[test] @@ -273,6 +274,228 @@ mod 
tests { module.free(); } + + #[test] + fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GRLWECt::from_prod_by_rgsw_scratch_space( + &module, + ct_grlwe_out.size(), + ct_grlwe_in.size(), + ct_rgsw.size(), + ) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_in.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe_out.from_prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + + 
let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe_out.rows()).for_each(|row_i| { + ct_grlwe_out.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } + + #[test] + fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, 
ct_grlwe.size()) + | GRLWECt::prod_by_rgsw_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe.rows()).for_each(|row_i| { + ct_grlwe.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } } 
#[allow(dead_code)] diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs index fe0038d..448bdfb 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/rlwe_dft.rs @@ -4,10 +4,13 @@ mod tests { elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::grlwe::noise_grlwe_rlwe_product, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, + }; + use base2k::{ + FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut, }; - use base2k::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; #[test] @@ -213,4 +216,233 @@ mod tests { module.free(); } + + #[test] + fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = 
Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::from_prod_by_rgsw_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); + ct_rlwe_dft_out.from_prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + 
assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + 
Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } } From d8a7d6cdaf16b016d623f2dadb8cb91195058031 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 12 May 2025 14:40:17 +0200 Subject: [PATCH 60/87] Some traits updates + added missing tests for products on RGSWCt --- core/src/elem.rs | 150 ++-- core/src/grlwe.rs | 108 +-- core/src/lib.rs | 1 + core/src/rgsw.rs | 96 +-- core/src/rlwe.rs | 148 ++-- core/src/test_fft64/grlwe.rs | 993 +++++++++++++------------ core/src/test_fft64/rgsw.rs | 627 ++++++++++++++-- core/src/test_fft64/rlwe.rs | 1197 +++++++++++++++---------------- core/src/test_fft64/rlwe_dft.rs | 889 ++++++++++++----------- 9 files changed, 2295 insertions(+), 1914 deletions(-) diff --git a/core/src/elem.rs b/core/src/elem.rs index 94311bc..b66c86d 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -66,92 +66,88 @@ pub trait SetRow { VecZnxDft: VecZnxDftToRef; } -pub trait ProdByScratchSpace { - fn 
prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; +pub trait ProdInplaceScratchSpace { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; } -pub trait ProdBy { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; - - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; -} - -pub trait FromProdByScratchSpace { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait FromProdBy { - fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &L, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; - - fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &L, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; -} - -pub(crate) trait MatZnxDftProducts: Infos +pub trait ProdInplace where - MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef, { - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch); +} - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; +pub trait ProdScratchSpace { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_scratch_space(module: &Module, 
res_size: usize, lhs: usize, rhs: usize) -> usize; +} - fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size) +pub trait Product +where + MatZnxDft: MatZnxDftToRef, +{ + type Lhs; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch); +} + +pub(crate) trait MatRLWEProductScratchSpace { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; + + fn prod_with_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_scratch_space(module, res_size, res_size, mat_size) } - fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + fn prod_with_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + (Self::prod_with_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + module.bytes_of_vec_znx(2, a_size) + module.bytes_of_vec_znx(2, res_size) } - fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + fn prod_with_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + (Self::prod_with_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + module.bytes_of_vec_znx(2, res_size) } - fn mul_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, 
mat_size) + module.bytes_of_vec_znx_dft(2, a_size) + fn prod_with_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_dft_scratch_space(module, res_size, a_size, mat_size) + + module.bytes_of_vec_znx_dft(2, a_size) + + module.bytes_of_vec_znx_dft(2, res_size) } - fn mul_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) } +} - fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) +pub(crate) trait MatRLWEProduct: Infos { + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn prod_with_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + VecZnxToRef, { unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. 
+ self.prod_with_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); } } - fn mul_rlwe_dft( + fn prod_with_rlwe_dft( &self, module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, + res: &mut RLWECtDft, + a: &RLWECtDft, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, { let log_base2k: usize = self.log_base2k(); @@ -180,15 +176,15 @@ where log_k: res.log_k(), }; - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + self.prod_with_rlwe(module, &mut res_idft, &a_idft, scratch_2); module.vec_znx_dft(res, 0, &res_idft, 0); module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + fn prod_with_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) where - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, { let log_base2k: usize = self.log_base2k(); @@ -209,47 +205,55 @@ where res.idft(module, &mut res_idft, scratch_1); - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + self.prod_with_rlwe_inplace(module, &mut res_idft, scratch_1); module.vec_znx_dft(res, 0, &res_idft, 0); module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_mat_rlwe(&self, module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + fn prod_with_mat_rlwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) where - A: GetRow + Infos, - R: SetRow + Infos, + LHS: GetRow + Infos, + RES: SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + let mut tmp_a_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { data: tmp_row_data, log_base2k: a.log_base2k(), log_k: a.log_k(), }; + 
let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_res_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + let min_rows: usize = res.rows().min(a.rows()); (0..res.rows()).for_each(|row_i| { (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); + a.get_row(module, row_i, col_j, &mut tmp_a_row); + self.prod_with_rlwe_dft(module, &mut tmp_res_row, &tmp_a_row, scratch2); + res.set_row(module, row_i, col_j, &tmp_res_row); }); }); - tmp_row.data.zero(); + tmp_res_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_row); + res.set_row(module, row_i, col_j, &tmp_res_row); }); }); } - fn mul_mat_rlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + fn prod_with_mat_rlwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) where - R: GetRow + SetRow + Infos, + RES: GetRow + SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); @@ -262,7 +266,7 @@ where (0..res.rows()).for_each(|row_i| { (0..res.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + self.prod_with_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); }); }); diff --git a/core/src/grlwe.rs b/core/src/grlwe.rs index 9c8c5b8..80c976d 100644 --- a/core/src/grlwe.rs +++ b/core/src/grlwe.rs @@ -7,7 +7,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, + elem::{ + GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, 
ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, + Product, SetRow, + }, keys::SecretKeyDft, rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, @@ -30,18 +33,6 @@ impl GRLWECt, B> { } } -impl GRLWECt -where - MatZnxDft: MatZnxDftToRef, -{ - pub fn get_row(&self, module: &Module, row_i: usize, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, 0); - } -} - impl Infos for GRLWECt { type Inner = MatZnxDft; @@ -202,18 +193,20 @@ where } } -impl MatZnxDftProducts, C> for GRLWECt -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { +impl MatRLWEProductScratchSpace for GRLWECt, FFT64> { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + module.bytes_of_vec_znx_dft(1, a_size))) } +} - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) +impl MatRLWEProduct for GRLWECt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -247,79 +240,52 @@ where } } -impl ProdByScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize 
{ - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for GRLWECt, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for GRLWECt +impl ProdInplace for GRLWECt where GRLWECt: GetRow + SetRow + Infos, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + 
rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, GRLWECt> for GRLWECt +impl Product for GRLWECt where - GRLWECt: GetRow + SetRow + Infos, - GRLWECt: GetRow + Infos, + MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &GRLWECt, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + type Lhs = GRLWECt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &GRLWECt, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index a93d44e..bed71cc 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -3,5 +3,6 @@ pub mod grlwe; pub mod keys; pub mod rgsw; pub mod rlwe; +#[cfg(test)] mod test_fft64; mod utils; diff --git a/core/src/rgsw.rs b/core/src/rgsw.rs index c4c7c1c..b866252 100644 --- a/core/src/rgsw.rs +++ b/core/src/rgsw.rs @@ -7,7 +7,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, + elem::{ + GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, + Product, SetRow, + }, grlwe::GRLWECt, keys::SecretKeyDft, rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, @@ -184,17 +187,19 @@ where } } -impl MatZnxDftProducts, C> for RGSWCt -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, 
a_size: usize, rgsw_size: usize) -> usize { +impl MatRLWEProductScratchSpace for RGSWCt, FFT64> { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, rgsw_size) + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) | module.vec_znx_big_normalize_tmp_bytes()) } +} - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) +impl MatRLWEProduct for RGSWCt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -227,79 +232,52 @@ where } } -impl ProdByScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RGSWCt, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RGSWCt, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for RGSWCt, FFT64> { + fn 
prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RGSWCt +impl ProdInplace for RGSWCt where RGSWCt: GetRow + SetRow + Infos, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, RGSWCt> for RGSWCt +impl Product for RGSWCt where - RGSWCt: GetRow + SetRow + Infos, - RGSWCt: GetRow + Infos, + MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &RGSWCt, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + type Lhs = RGSWCt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, 
scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &RGSWCt, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } } diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs index ef1be64..2dab803 100644 --- a/core/src/rlwe.rs +++ b/core/src/rlwe.rs @@ -6,7 +6,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace}, + elem::{Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, rgsw::RGSWCt, @@ -84,70 +84,54 @@ where } } -impl ProdByScratchSpace for RLWECt> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RLWECt> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RLWECt> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( - module, 
res_size, lhs, rhs, - ) +impl ProdScratchSpace for RLWECt> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RLWECt +impl ProdInplace for RLWECt where VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, RLWECt> for RLWECt +impl Product for RLWECt where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &RLWECt, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe(module, self, lhs, scratch); + type Lhs = RLWECt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe(module, self, lhs, scratch); } - fn 
from_prod_by_rgsw(&mut self, module: &Module, lhs: &RLWECt, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe(module, self, lhs, scratch); } } @@ -496,7 +480,7 @@ where impl RLWECtDft where - VecZnxDft: VecZnxDftToRef, + RLWECtDft: VecZnxDftToRef, { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { @@ -505,7 +489,7 @@ where pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where - VecZnx: VecZnxToMut, + RLWECt: VecZnxToMut, { #[cfg(debug_assertions)] { @@ -518,8 +502,8 @@ where let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); } @@ -665,79 +649,53 @@ impl RLWECtDft { } } -impl ProdByScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( - module, lhs, rhs, - ) + fn 
prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RLWECtDft, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RLWECtDft +impl ProdInplace for RLWECtDft where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft_inplace(module, self, scratch); } } -impl FromProdBy, RLWECtDft> for RLWECtDft +impl 
Product for RLWECtDft where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &RLWECtDft, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft(module, self, lhs, scratch); + type Lhs = RLWECtDft; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft(module, self, lhs, scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &RLWECtDft, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft(module, self, lhs, scratch); } } diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 44fefd6..81c1023 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,504 +1,499 @@ -#[cfg(test)] - -mod tests { - use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; - use sampling::source::Source; - - use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; - - #[test] - fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = 
Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - }); - - module.free(); - } - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - 
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::from_prod_by_grlwe_scratch_space( - &module, - ct_grlwe_s0s2.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.from_prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 
0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::prod_by_grlwe_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, 
- sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.prod_by_grlwe(&module, &ct_grlwe_s1s2, scratch.borrow()); - - let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) - | 
RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GRLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_grlwe_out.size(), - ct_grlwe_in.size(), - ct_rgsw.size(), - ) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_in.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.from_prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe_out.rows()).for_each(|row_i| { - ct_grlwe_out.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - 
(noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) - | GRLWECt::prod_by_rgsw_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, 
log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe.rows()).for_each(|row_i| { - ct_grlwe.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECtDft, RLWEPt}, + test_fft64::rgsw::noise_rgsw_product, +}; + +#[test] +fn encrypt_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + 
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); +} + +#[test] +fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | 
GRLWECt::prod_by_grlwe_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s2.prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); 
+ + module.free(); +} + +#[test] +fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GRLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + 
ct_grlwe_s0s1.prod_by_grlwe_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + + let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GRLWECt::prod_by_rgsw_scratch_space( + &module, + 
ct_grlwe_out.size(), + ct_grlwe_in.size(), + ct_rgsw.size(), + ) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_in.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe_out.prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe_out.rows()).for_each(|row_i| { + ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); +} + +#[test] +fn 
prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) + | GRLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k 
as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe.rows()).for_each(|row_i| { + ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); } -#[allow(dead_code)] pub(crate) fn noise_grlwe_rlwe_product( n: f64, log_base2k: usize, diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 83df85b..50cd356 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -1,95 +1,582 @@ -#[cfg(test)] -mod tests { - use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, - }; - use sampling::source::Source; +use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +}; +use sampling::source::Source; - use crate::{ - elem::{GetRow, Infos}, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::rgsw::noise_rgsw_rlwe_product, - }; +use crate::{ + elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + 
keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECtDft, RLWEPt}, + test_fft64::grlwe::noise_grlwe_rlwe_product, +}; - #[test] - fn encrypt_rgsw_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; +#[test] +fn encrypt_rgsw_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), - ); + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut 
source_xs); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - (0..ct.cols()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct.get_row(&module, row_i, col_j, 
&mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - pt_want.data.zero(); - }); + pt_want.data.zero(); }); + }); - module.free(); - } + module.free(); } -#[allow(dead_code)] -pub(crate) fn noise_rgsw_rlwe_product( +#[test] +fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rgsw_in: usize = 45; + let log_k_rgsw_out: usize = 45; + let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_in, rows); + let mut ct_rgsw_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_out.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | RGSWCt::prod_by_grlwe_scratch_space( + &module, + ct_rgsw_out.size(), + 
ct_rgsw_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_in.encrypt_sk( + &module, + &pt_rgsw, + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_out.prod_by_grlwe(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_out); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + + (0..ct_rgsw_out.cols()).for_each(|col_j| { + (0..ct_rgsw_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + 
module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.2, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rgsw: usize = 45; + let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw, rows); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RGSWCt::prod_by_grlwe_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + 
&pt_rgsw, + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + + (0..ct_rgsw.cols()).for_each(|col_j| { + (0..ct_rgsw.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.2, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_rgsw_rhs: usize = 60; + let log_k_rgsw_lhs_in: usize = 45; + let log_k_rgsw_lhs_out: usize = 45; + let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw_rhs: RGSWCt, FFT64> = 
RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); + let mut ct_rgsw_lhs_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | RGSWCt::prod_by_rgsw_scratch_space( + &module, + ct_rgsw_lhs_out.size(), + ct_rgsw_lhs_in.size(), + ct_rgsw_rhs.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw_rhs.encrypt_sk( + &module, + &pt_rgsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs_in.encrypt_sk( + &module, + &pt_rgsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs_out.prod_by_rgsw(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); + let mut pt_big: VecZnxBig, FFT64> = 
module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); + + (0..ct_rgsw_lhs_out.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rgsw_lhs_in, + log_k_rgsw_rhs, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_rgsw_rhs: usize = 60; + let log_k_rgsw_lhs: usize = 45; + let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, 
log_k_rgsw_lhs, rows); + let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | RGSWCt::prod_by_rgsw_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw_rhs.encrypt_sk( + &module, + &pt_rgsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs.encrypt_sk( + &module, + &pt_rgsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs.prod_by_rgsw_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); + + (0..ct_rgsw_lhs.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs.rows()).for_each(|row_i| { + 
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rgsw_lhs, + log_k_rgsw_rhs, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +pub(crate) fn noise_rgsw_product( n: f64, log_base2k: usize, var_xs: f64, diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index acc10a1..a2fabb9 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -1,621 +1,618 @@ -#[cfg(test)] -mod tests_rlwe { - use base2k::{ - Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxViewMut, ZnxZero, - }; - use itertools::izip; - use sampling::source::Source; +use base2k::{ + Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, + ZnxViewMut, ZnxZero, +}; +use itertools::izip; +use sampling::source::Source; - use crate::{ - elem::{FromProdBy, 
FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{PublicKey, SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; +use crate::{ + elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{PublicKey, SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, +}; - #[test] - fn encrypt_sk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pt: usize = 30; +#[test] +fn encrypt_sk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pt: usize = 30; - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), - ); + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut 
source_xs); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); - let mut data_want: Vec = vec![0i64; module.n()]; + let mut data_want: Vec = vec![0i64; module.n()]; - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); - ct.encrypt_sk( - &module, - Some(&pt), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + ct.encrypt_sk( + &module, + Some(&pt), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); - pt.data.zero(); + pt.data.zero(); - ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - let mut data_have: Vec = vec![0i64; module.n()]; + let mut data_have: Vec = vec![0i64; module.n()]; - pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); - - module.free(); - } - - #[test] - fn encrypt_zero_sk() { - let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut 
source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); - } - - #[test] - fn encrypt_pk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pk: usize = 64; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); - pk.generate( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - ); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) - | RLWECt::decrypt_scratch_space(&module, ct.size()) - | 
RLWECt::encrypt_pk_scratch_space(&module, pk.size()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - Some(&pt_want), - &pk, - &mut source_xu, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); - } - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, 
ct_rlwe_in.size()) - | RLWECt::from_prod_by_grlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_out.from_prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; assert!( - (noise_have - noise_want).abs() <= 0.1, + (*a as f64 - b_scaled).abs() < 0.1, "{} {}", - noise_have, - noise_want - ); + *a as f64, + b_scaled + ) + }); - module.free(); - } - - #[test] - fn prod_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: 
usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_grlwe_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let 
noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = 
SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_out.from_prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = 
Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - 
module.free(); - } + module.free(); +} + +#[test] +fn encrypt_zero_sk() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); +} + +#[test] +fn encrypt_pk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + 
sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) + | RLWECt::decrypt_scratch_space(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); +} + +#[test] +fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = 
Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 
1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = 
noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, 
&sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 
32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); } diff --git a/core/src/test_fft64/rlwe_dft.rs 
b/core/src/test_fft64/rlwe_dft.rs index 448bdfb..fe71a09 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/rlwe_dft.rs @@ -1,448 +1,443 @@ -#[cfg(test)] -mod tests { - use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; - use base2k::{ - FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut, - }; - use sampling::source::Source; - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | 
RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECtDft::from_prod_by_grlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.from_prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = 
GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECtDft::prod_by_grlwe_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: 
f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | 
RLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.from_prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = 
RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let 
noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } +use crate::{ + elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, +}; +use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; +use sampling::source::Source; + +#[test] +fn by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut 
source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECtDft::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_rlwe_out_dft.prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + 
noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECtDft::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut 
source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + 
pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_rgsw_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); + ct_rlwe_dft_out.prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw_inplace() { + 
let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut 
ct_rlwe_dft); + ct_rlwe_dft.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); } From 31b14ee585442ae15a79cdda865981f0e37b5c33 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 13 May 2025 00:40:07 +0200 Subject: [PATCH 61/87] rework for GLWE --- base2k/src/module.rs | 213 ++--- core/Cargo.toml | 8 + core/benches/external_product_glwe_fft64.rs | 205 +++++ core/benches/keyswitch_glwe_fft64.rs | 200 +++++ core/src/elem.rs | 233 +----- core/src/encryption.rs | 105 +++ core/src/external_product.rs | 19 + core/src/{rgsw.rs => ggsw.rs} | 219 ++--- core/src/glwe.rs | 845 ++++++++++++++++++++ core/src/keys.rs | 24 +- core/src/keyswitch.rs | 20 + core/src/{grlwe.rs => keyswitch_key.rs} | 220 ++--- core/src/lib.rs | 10 +- core/src/rlwe.rs | 701 ---------------- core/src/test_fft64/grlwe.rs | 137 ++-- core/src/test_fft64/rgsw.rs | 154 ++-- core/src/test_fft64/rlwe.rs | 168 ++-- core/src/test_fft64/rlwe_dft.rs | 143 ++-- core/src/vec_glwe_product.rs | 197 +++++ 19 files changed, 2290 insertions(+), 1531 deletions(-) create mode 100644 core/benches/external_product_glwe_fft64.rs create mode 100644 
core/benches/keyswitch_glwe_fft64.rs create mode 100644 core/src/encryption.rs create mode 100644 core/src/external_product.rs rename core/src/{rgsw.rs => ggsw.rs} (51%) create mode 100644 core/src/glwe.rs create mode 100644 core/src/keyswitch.rs rename core/src/{grlwe.rs => keyswitch_key.rs} (51%) delete mode 100644 core/src/rlwe.rs create mode 100644 core/src/vec_glwe_product.rs diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 0e7d124..aab18b4 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,106 +1,107 @@ -use crate::GALOISGENERATOR; -use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; -use std::marker::PhantomData; - -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum BACKEND { - FFT64, - NTT120, -} - -pub trait Backend { - const KIND: BACKEND; - fn module_type() -> u32; -} - -pub struct FFT64; -pub struct NTT120; - -impl Backend for FFT64 { - const KIND: BACKEND = BACKEND::FFT64; - fn module_type() -> u32 { - 0 - } -} - -impl Backend for NTT120 { - const KIND: BACKEND = BACKEND::NTT120; - fn module_type() -> u32 { - 1 - } -} - -pub struct Module { - pub ptr: *mut MODULE, - n: usize, - _marker: PhantomData, -} - -impl Module { - // Instantiates a new module. 
- pub fn new(n: usize) -> Self { - unsafe { - let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); - if m.is_null() { - panic!("Failed to create module."); - } - Self { - ptr: m, - n: n, - _marker: PhantomData, - } - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - pub fn cyclotomic_order(&self) -> u64 { - (self.n() << 1) as _ - } - - // Returns GALOISGENERATOR^|generator| * sign(generator) - pub fn galois_element(&self, generator: i64) -> i64 { - if generator == 0 { - return 1; - } - ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() - } - - // Returns gen^-1 - pub fn galois_element_inv(&self, generator: i64) -> i64 { - if generator == 0 { - panic!("cannot invert 0") - } - ((mod_exp_u64( - generator.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() - } - - pub fn free(self) { - unsafe { delete_module_info(self.ptr) } - drop(self); - } -} - -fn mod_exp_u64(x: u64, e: usize) -> u64 { - let mut y: u64 = 1; - let mut x_pow: u64 = x; - let mut exp = e; - while exp > 0 { - if exp & 1 == 1 { - y = y.wrapping_mul(x_pow); - } - x_pow = x_pow.wrapping_mul(x_pow); - exp >>= 1; - } - y -} +use crate::GALOISGENERATOR; +use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum BACKEND { + FFT64, + NTT120, +} + +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; +} + +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + n: usize, + _marker: PhantomData, 
+} + +impl Module { + // Instantiates a new module. + pub fn new(n: usize) -> Self { + unsafe { + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); + if m.is_null() { + panic!("Failed to create module."); + } + Self { + ptr: m, + n: n, + _marker: PhantomData, + } + } + } + + pub fn n(&self) -> usize { + self.n + } + + pub fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + pub fn cyclotomic_order(&self) -> u64 { + (self.n() << 1) as _ + } + + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { + return 1; + } + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() + } + + // Returns gen^-1 + pub fn galois_element_inv(&self, generator: i64) -> i64 { + if generator == 0 { + panic!("cannot invert 0") + } + ((mod_exp_u64( + generator.abs() as u64, + (self.cyclotomic_order() - 1) as usize, + ) & (self.cyclotomic_order() - 1)) as i64) + * generator.signum() + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { delete_module_info(self.ptr) } + } +} + +fn mod_exp_u64(x: u64, e: usize) -> u64 { + let mut y: u64 = 1; + let mut x_pow: u64 = x; + let mut exp = e; + while exp > 0 { + if exp & 1 == 1 { + y = y.wrapping_mul(x_pow); + } + x_pow = x_pow.wrapping_mul(x_pow); + exp >>= 1; + } + y +} diff --git a/core/Cargo.toml b/core/Cargo.toml index 692c4fb..a54bd5a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -10,3 +10,11 @@ base2k = {path="../base2k"} sampling = {path="../sampling"} rand_distr = {workspace = true} itertools = {workspace = true} + +[[bench]] +name = "external_product_glwe_fft64" +harness = false + +[[bench]] +name = "keyswitch_glwe_fft64" +harness = false \ No newline at end of file diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs new file mode 100644 index 0000000..4462fab --- 
/dev/null +++ b/core/benches/external_product_glwe_fft64.rs @@ -0,0 +1,205 @@ +use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rlwe::{ + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::GLWECiphertext, + keys::{SecretKey, SecretKeyFourier}, +}; +use sampling::source::Source; + +fn bench_external_product_glwe_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("external_product_glwe_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe_in: usize, + k_rlwe_out: usize, + k_rgsw: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe_in: usize = p.k_rlwe_in; + let k_rlwe_out: usize = p.k_rlwe_out; + let k_rgsw: usize = p.k_rgsw; + + let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); + let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + 
sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + ct_rlwe_out.external_product( + black_box(&module), + black_box(&ct_rlwe_in), + black_box(&ct_rgsw), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 10, + basek: 7, + k_rlwe_in: 27, + k_rlwe_out: 27, + k_rgsw: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("external_product_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe: usize, + k_rgsw: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe: usize = p.k_rlwe; + let k_rgsw: usize = p.k_rgsw; + + let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); + let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut 
source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + let scratch_borrow = scratch.borrow(); + (0..1374).for_each(|i| { + ct_rlwe.external_product_inplace( + black_box(&module), + black_box(&ct_rgsw), + black_box(scratch_borrow), + ); + }); + } + } + + let params_set: Vec = vec![Params { + log_n: 9, + basek: 18, + k_rlwe: 27, + k_rgsw: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_external_product_glwe_fft64, + bench_external_product_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs new file mode 100644 index 0000000..3a25360 --- /dev/null +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -0,0 +1,200 @@ +use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rlwe::{ + elem::Infos, + encryption::EncryptSkScratchSpace, + glwe::GLWECiphertext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, +}; +use sampling::source::Source; + +fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { + let mut group 
= c.benchmark_group("keyswitch_glwe_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe_in: usize, + k_rlwe_out: usize, + k_grlwe: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe_in: usize = p.k_rlwe_in; + let k_rlwe_out: usize = p.k_rlwe_out; + let k_grlwe: usize = p.k_grlwe; + + let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); + let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + ct_rlwe_out.keyswitch( + black_box(&module), + black_box(&ct_rlwe_in), + black_box(&ct_grlwe), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 16, + 
basek: 50, + k_rlwe_in: 1250, + k_rlwe_out: 1250, + k_grlwe: 1250 + 66, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("keyswitch_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe: usize, + k_grlwe: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe: usize = p.k_rlwe; + let k_grlwe: usize = p.k_grlwe; + + let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); + let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || 
{ + ct_rlwe.keyswitch_inplace( + black_box(&module), + black_box(&ct_grlwe), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 9, + basek: 18, + k_rlwe: 27, + k_grlwe: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_keyswitch_glwe_fft64, + bench_keyswitch_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git a/core/src/elem.rs b/core/src/elem.rs index b66c86d..bf5ca1e 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,14 +1,6 @@ -use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, -}; +use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{ - grlwe::GRLWECt, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft}, - utils::derive_size, -}; +use crate::{glwe::GLWECiphertextFourier, utils::derive_size}; pub trait Infos { type Inner: ZnxInfos; @@ -31,244 +23,37 @@ pub trait Infos { } /// Returns the number of polynomials in each row. - fn cols(&self) -> usize { + fn rank(&self) -> usize { self.inner().cols() } /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k())); + debug_assert_eq!(size, derive_size(self.basek(), self.k())); size } /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { - self.rows() * self.cols() * self.size() + self.rows() * self.rank() * self.size() } /// Returns the base 2 logarithm of the ciphertext base. - fn log_base2k(&self) -> usize; + fn basek(&self) -> usize; /// Returns the bit precision of the ciphertext. 
- fn log_k(&self) -> usize; + fn k(&self) -> usize; } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef; } - -pub trait ProdInplaceScratchSpace { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; -} - -pub trait ProdInplace -where - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch); - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch); -} - -pub trait ProdScratchSpace { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait Product -where - MatZnxDft: MatZnxDftToRef, -{ - type Lhs; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch); - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch); -} - -pub(crate) trait MatRLWEProductScratchSpace { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; - - fn prod_with_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_scratch_space(module, res_size, res_size, mat_size) - } - - fn prod_with_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - 
(Self::prod_with_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - fn prod_with_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - (Self::prod_with_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - fn prod_with_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_dft_scratch_space(module, res_size, a_size, mat_size) - + module.bytes_of_vec_znx_dft(2, a_size) - + module.bytes_of_vec_znx_dft(2, res_size) - } - - fn prod_with_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) - } -} - -pub(crate) trait MatRLWEProduct: Infos { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; - - fn prod_with_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. 
- self.prod_with_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - fn prod_with_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.prod_with_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - fn prod_with_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.prod_with_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - fn prod_with_mat_rlwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: 
&mut Scratch) - where - LHS: GetRow + Infos, - RES: SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_a_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_res_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_a_row); - self.prod_with_rlwe_dft(module, &mut tmp_res_row, &tmp_a_row, scratch2); - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - - tmp_res_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - } - - fn prod_with_mat_rlwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) - where - RES: GetRow + SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - res.get_row(module, row_i, col_j, &mut tmp_row); - self.prod_with_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); - }); - }); - } -} diff --git a/core/src/encryption.rs b/core/src/encryption.rs new file mode 100644 index 0000000..915834c --- /dev/null +++ b/core/src/encryption.rs @@ -0,0 +1,105 @@ +use base2k::{Backend, Module, Scratch}; +use sampling::source::Source; + +pub trait EncryptSkScratchSpace { + fn 
encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptSk { + type Ciphertext; + type Plaintext; + type SecretKey; + + fn encrypt_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptZeroSkScratchSpace { + fn encrypt_zero_sk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptZeroSk { + type Ciphertext; + type SecretKey; + + fn encrypt_zero_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptPkScratchSpace { + fn encrypt_pk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptPk { + type Ciphertext; + type Plaintext; + type PublicKey; + + fn encrypt_pk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + pk: &Self::PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptZeroPkScratchSpace { + fn encrypt_zero_pk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptZeroPk { + type Ciphertext; + type PublicKey; + + fn encrypt_zero_pk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pk: &Self::PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait Decrypt { + type Plaintext; + type Ciphertext; + type SecretKey; + + fn decrypt( + &self, + module: &Module, + pt: &mut Self::Plaintext, + ct: &Self::Ciphertext, + sk: &Self::SecretKey, + scratch: &mut Scratch, + ); +} diff --git a/core/src/external_product.rs b/core/src/external_product.rs new file mode 100644 index 0000000..e8d0a7e --- /dev/null +++ 
b/core/src/external_product.rs @@ -0,0 +1,19 @@ +use base2k::{FFT64, Module, Scratch}; + +pub trait ExternalProductScratchSpace { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait ExternalProduct { + type Lhs; + type Rhs; + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); +} +pub trait ExternalProductInplaceScratchSpace { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; +} + +pub trait ExternalProductInplace { + type Rhs; + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); +} diff --git a/core/src/rgsw.rs b/core/src/ggsw.rs similarity index 51% rename from core/src/rgsw.rs rename to core/src/ggsw.rs index b866252..79b12a5 100644 --- a/core/src/rgsw.rs +++ b/core/src/ggsw.rs @@ -7,23 +7,26 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{ - GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, - Product, SetRow, + elem::{GetRow, Infos, SetRow}, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - grlwe::GRLWECt, - keys::SecretKeyDft, - rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, encrypt_glwe_sk}, + keys::SecretKeyFourier, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct RGSWCt { +pub struct GGSWCiphertext { pub data: MatZnxDft, pub log_base2k: usize, pub log_k: usize, } -impl RGSWCt, B> { +impl GGSWCiphertext, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { 
Self { data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), @@ -33,23 +36,23 @@ impl RGSWCt, B> { } } -impl Infos for RGSWCt { +impl Infos for GGSWCiphertext { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { &self.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.log_k } } -impl MatZnxDftToMut for RGSWCt +impl MatZnxDftToMut for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { @@ -58,7 +61,7 @@ where } } -impl MatZnxDftToRef for RGSWCt +impl MatZnxDftToRef for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { @@ -67,9 +70,9 @@ where } } -impl RGSWCt, FFT64> { +impl GGSWCiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_space(module, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) @@ -78,9 +81,9 @@ impl RGSWCt, FFT64> { pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut RGSWCt, + ct: &mut GGSWCiphertext, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -92,21 +95,21 @@ pub fn encrypt_rgsw_sk( ScalarZnxDft: ScalarZnxDftToRef, { let size: usize = ct.size(); - let log_base2k: usize = ct.log_base2k(); + let log_base2k: usize = ct.basek(); let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; - let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; (0..ct.rows()).for_each(|row_j| { @@ -114,9 +117,9 @@ pub fn encrypt_rgsw_sk( module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - (0..ct.cols()).for_each(|col_i| { + (0..ct.rank()).for_each(|col_i| { // rlwe encrypt of vec_znx_pt into vec_znx_ct - encrypt_rlwe_sk( + encrypt_glwe_sk( module, &mut vec_znx_ct, Some((&vec_znx_pt, col_i)), @@ -141,12 +144,12 @@ pub fn encrypt_rgsw_sk( }); } -impl RGSWCt { +impl GGSWCiphertext { pub fn encrypt_sk( &mut self, module: &Module, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -163,11 +166,11 @@ impl RGSWCt { } } -impl GetRow for RGSWCt +impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut, { @@ -175,11 +178,11 @@ where } } -impl SetRow for RGSWCt +impl SetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef, { @@ -187,30 +190,118 @@ where } } -impl MatRLWEProductScratchSpace for RGSWCt, FFT64> { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { +impl KeySwitchScratchSpace for GGSWCiphertext, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl KeySwitch for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GGSWCiphertext; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GGSWCiphertext, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GGSWCiphertext +where + MatZnxDft: 
MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, rhs, scratch); + } +} + +impl ExternalProductScratchSpace for GGSWCiphertext, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ExternalProduct for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GGSWCiphertext; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GGSWCiphertext, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, rgsw_size) + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) | module.vec_znx_big_normalize_tmp_bytes()) } } -impl MatRLWEProduct for RGSWCt +impl VecGLWEProduct for 
GGSWCiphertext where MatZnxDft: MatZnxDftToRef + ZnxInfos, { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.log_base2k(); + let log_base2k: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); @@ -231,53 +322,3 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } - -impl ProdInplaceScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RGSWCt -where - RGSWCt: GetRow + SetRow + Infos, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - 
rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } -} - -impl Product for RGSWCt -where - MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RGSWCt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } -} diff --git a/core/src/glwe.rs b/core/src/glwe.rs new file mode 100644 index 0000000..e50582d --- /dev/null +++ b/core/src/glwe.rs @@ -0,0 +1,845 @@ +use base2k::{ + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + encryption::{EncryptSk, EncryptSkScratchSpace, EncryptZeroSkScratchSpace}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + keys::{PublicKey, SecretDistribution, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GLWECiphertext> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { 
+ data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), 2); + assert_eq!(res.rank(), 2); + assert_eq!(self.basek(), res.basek()) + } + + module.vec_znx_dft(res, 0, self, 0); + module.vec_znx_dft(res, 1, self, 1); + } +} + +impl KeySwitchScratchSpace for GLWECiphertext> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl KeySwitch for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertext; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWECiphertext> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GLWECiphertext +where + VecZnx: VecZnxToMut + 
VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_inplace(module, self, scratch); + } +} + +impl ExternalProductScratchSpace for GLWECiphertext> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl ExternalProduct for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertext; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWECiphertext> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_inplace(module, self, scratch); + } +} + +impl GLWECiphertext> { + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + 
+impl EncryptSkScratchSpace for GLWECiphertext> { + fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl EncryptSk for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + type Ciphertext = GLWECiphertext; + type Plaintext = GLWEPlaintext; + type SecretKey = SecretKeyFourier; + + fn encrypt_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) { + encrypt_glwe_sk( + module, + ct, + Some((pt, 0)), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch, + ); + } +} + +pub(crate) fn encrypt_glwe_sk( + module: &Module, + ct: &mut GLWECiphertext, + pt: Option<(&GLWEPlaintext, usize)>, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.basek(); + let log_k: usize = ct.k(); + let size: usize = ct.size(); + + // c1 = a + ct.data.fill_uniform(log_base2k, 1, size, source_xa); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = m - c0_big + if let Some((pt, col)) = pt { + match col { + 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), + 1 => { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + 
module.vec_znx_add_inplace(ct, 1, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); + } + _ => panic!("invalid target column: {}", col), + } + } else { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + } + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); +} + +pub fn decrypt_glwe( + module: &Module, + pt: &mut GLWEPlaintext

, + ct: &GLWECiphertext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.basek(); + pt.log_k = pt.k().min(ct.k()); +} + +impl GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_glwe_sk( + module, + self, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ) + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_glwe_sk::( + module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + encrypt_glwe_pk( + module, + self, + Some(pt), + pk, + source_xu, 
+ source_xe, + sigma, + bound, + scratch, + ) + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + { + encrypt_glwe_pk::( + module, self, None, pk, source_xu, source_xe, sigma, bound, scratch, + ) + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_glwe(module, pt, self, sk_dft, scratch); + } +} + +pub(crate) fn encrypt_glwe_pk( + module: &Module, + ct: &mut GLWECiphertext, + pt: Option<&GLWEPlaintext

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.basek(), pk.basek()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.basek(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly initialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for GLWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for GLWEPlaintext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWEPlaintext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWEPlaintext> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct GLWECiphertextFourier { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GLWECiphertextFourier, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GLWECiphertextFourier { + type Inner = VecZnxDft; + + fn inner(&self) -> 
&Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWECiphertextFourier +where + GLWECiphertextFourier: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + GLWECiphertext: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), 2); + assert_eq!(res.rank(), 2); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); + module.vec_znx_big_normalize(self.basek(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.basek(), res, 1, &res_big, 1, scratch1); + } +} + +pub(crate) fn encrypt_zero_glwe_dft_sk( + module: &Module, + ct: &mut GLWECiphertextFourier, + sk: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.basek(); + let log_k: usize = ct.k(); + let size: usize = ct.size(); + + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), + 
_ => {} + } + assert_eq!(ct.rank(), 2); + } + + // ct[1] = DFT(a) + { + let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); + tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); + } + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + // c0_dft = ct[1] * DFT(s) + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); + } + + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. + let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); + module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); + module.vec_znx_negate_inplace(&mut tmp_znx, 0); + // ct[0] = DFT(-as + e) + module.vec_znx_dft(ct, 0, &tmp_znx, 0); +} + +impl GLWECiphertextFourier, FFT64> { + pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } +} + +pub fn decrypt_rlwe_dft( + module: &Module, + pt: &mut GLWEPlaintext

, + ct: &GLWECiphertextFourier, + sk: &SecretKeyFourier, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + { + let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.basek(); + pt.log_k = pt.k().min(ct.k()); +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + pub(crate) fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_zero_glwe_dft_sk( + module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } + + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext

, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx

: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); + } +} + +impl KeySwitchScratchSpace for GLWECiphertextFourier, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl KeySwitch for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertextFourier; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWECiphertextFourier, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} + +impl ExternalProductScratchSpace for GLWECiphertextFourier, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl ExternalProduct for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertextFourier; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: 
&Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWECiphertextFourier, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs index 8285f85..eaa569e 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -5,7 +5,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{elem::Infos, rlwe::RLWECtDft}; +use crate::{elem::Infos, glwe::GLWECiphertextFourier}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -67,12 +67,12 @@ where } } -pub struct SecretKeyDft { +pub struct SecretKeyFourier { pub data: ScalarZnxDft, pub dist: SecretDistribution, } -impl SecretKeyDft, B> { +impl SecretKeyFourier, B> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar_znx_dft(1), @@ -82,7 +82,7 @@ impl SecretKeyDft, B> { pub fn dft(&mut self, module: &Module, sk: &SecretKey) where - SecretKeyDft, B>: ScalarZnxDftToMut, + SecretKeyFourier, B>: ScalarZnxDftToMut, SecretKey: ScalarZnxToRef, { #[cfg(debug_assertions)] @@ -98,7 +98,7 @@ impl SecretKeyDft, B> { } } -impl ScalarZnxDftToMut for SecretKeyDft +impl ScalarZnxDftToMut for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToMut, { @@ -107,7 +107,7 @@ where } } -impl ScalarZnxDftToRef for SecretKeyDft +impl ScalarZnxDftToRef for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToRef, { @@ -117,14 +117,14 @@ where } pub struct PublicKey { 
- pub data: RLWECtDft, + pub data: GLWECiphertextFourier, pub dist: SecretDistribution, } impl PublicKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: RLWECtDft::new(module, log_base2k, log_k), + data: GLWECiphertextFourier::new(module, log_base2k, log_k), dist: SecretDistribution::NONE, } } @@ -137,11 +137,11 @@ impl Infos for PublicKey { &self.data.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.data.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.data.log_k } } @@ -168,7 +168,7 @@ impl PublicKey { pub fn generate( &mut self, module: &Module, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -186,7 +186,7 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_zero_sk_scratch_space( module, self.size(), )); diff --git a/core/src/keyswitch.rs b/core/src/keyswitch.rs new file mode 100644 index 0000000..c77ccb4 --- /dev/null +++ b/core/src/keyswitch.rs @@ -0,0 +1,20 @@ +use base2k::{FFT64, Module, Scratch}; + +pub trait KeySwitchScratchSpace { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait KeySwitch { + type Lhs; + type Rhs; + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); +} + +pub trait KeySwitchInplaceScratchSpace { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; +} + +pub trait KeySwitchInplace { + type Rhs; + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); +} diff --git a/core/src/grlwe.rs b/core/src/keyswitch_key.rs similarity index 51% rename from core/src/grlwe.rs rename to 
core/src/keyswitch_key.rs index 80c976d..cb4c248 100644 --- a/core/src/grlwe.rs +++ b/core/src/keyswitch_key.rs @@ -7,23 +7,26 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{ - GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, - Product, SetRow, + elem::{GetRow, Infos, SetRow}, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - keys::SecretKeyDft, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::SecretKeyFourier, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct GRLWECt { +pub struct GLWEKeySwitchKey { pub data: MatZnxDft, pub log_base2k: usize, pub log_k: usize, } -impl GRLWECt, B> { +impl GLWEKeySwitchKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { Self { data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), @@ -33,23 +36,23 @@ impl GRLWECt, B> { } } -impl Infos for GRLWECt { +impl Infos for GLWEKeySwitchKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { &self.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.log_k } } -impl MatZnxDftToMut for GRLWECt +impl MatZnxDftToMut for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToMut, { @@ -58,7 +61,7 @@ where } } -impl MatZnxDftToRef for GRLWECt +impl MatZnxDftToRef for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef, { @@ -67,20 +70,20 @@ where } } -impl GRLWECt, FFT64> { +impl GLWEKeySwitchKey, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - 
RLWECt::encrypt_sk_scratch_space(module, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } } -pub fn encrypt_grlwe_sk( +pub fn encrypt_glwe_key_switch_key_sk( module: &Module, - ct: &mut GRLWECt, + ct: &mut GLWEKeySwitchKey, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -93,22 +96,22 @@ pub fn encrypt_grlwe_sk( { let rows: usize = ct.rows(); let size: usize = ct.size(); - let log_base2k: usize = ct.log_base2k(); + let log_base2k: usize = ct.basek(); let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); - let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; - let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; (0..rows).for_each(|row_i| { @@ -119,7 +122,7 @@ pub fn encrypt_grlwe_sk( // rlwe encrypt of vec_znx_pt into vec_znx_ct vec_znx_ct.encrypt_sk( module, - Some(&vec_znx_pt), + &vec_znx_pt, sk_dft, source_xa, source_xe, @@ -139,12 +142,12 @@ pub fn encrypt_grlwe_sk( }); } -impl GRLWECt { +impl GLWEKeySwitchKey { pub fn encrypt_sk( &mut self, module: &Module, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -155,17 +158,17 @@ impl GRLWECt { ScalarZnx

: ScalarZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - encrypt_grlwe_sk( + encrypt_glwe_key_switch_key_sk( module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } } -impl GetRow for GRLWECt +impl GetRow for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut, { @@ -177,11 +180,11 @@ where } } -impl SetRow for GRLWECt +impl SetRow for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef, { @@ -193,8 +196,92 @@ where } } -impl MatRLWEProductScratchSpace for GRLWECt, FFT64> { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { +impl KeySwitchScratchSpace for GLWEKeySwitchKey, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl KeySwitch for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWEKeySwitchKey; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace 
for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, rhs, scratch); + } +} + +impl ExternalProductScratchSpace for GLWEKeySwitchKey, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ExternalProduct for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWEKeySwitchKey; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl VecGLWEProductScratchSpace for GLWEKeySwitchKey, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) @@ -202,22 +289,27 @@ impl MatRLWEProductScratchSpace for 
GRLWECt, FFT64> { } } -impl MatRLWEProduct for GRLWECt +impl VecGLWEProduct for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef + ZnxInfos, { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.log_base2k(); + let log_base2k: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); @@ -239,53 +331,3 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } - -impl ProdInplaceScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for GRLWECt -where - GRLWECt: GetRow + SetRow + Infos, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut 
self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } -} - -impl Product for GRLWECt -where - MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GRLWECt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } -} diff --git a/core/src/lib.rs b/core/src/lib.rs index bed71cc..97db860 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,8 +1,12 @@ pub mod elem; -pub mod grlwe; +pub mod encryption; +pub mod external_product; +pub mod ggsw; +pub mod glwe; pub mod keys; -pub mod rgsw; -pub mod rlwe; +pub mod keyswitch; +pub mod keyswitch_key; #[cfg(test)] mod test_fft64; mod utils; +pub mod vec_glwe_product; diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs deleted file mode 100644 index 2dab803..0000000 --- a/core/src/rlwe.rs +++ /dev/null @@ -1,701 +0,0 @@ -use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; -use sampling::source::Source; - -use crate::{ - elem::{Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{PublicKey, SecretDistribution, SecretKeyDft}, - rgsw::RGSWCt, - utils::derive_size, -}; - -pub struct RLWECt { - pub 
data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWECt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWECt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWECt -where - VecZnx: VecZnxToRef, -{ - #[allow(dead_code)] - pub(crate) fn dft(&self, module: &Module, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } - - module.vec_znx_dft(res, 0, self, 0); - module.vec_znx_dft(res, 1, self, 1); - } -} - -impl ProdInplaceScratchSpace for RLWECt> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RLWECt> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as 
MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RLWECt -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_inplace(module, self, scratch); - } -} - -impl Product for RLWECt -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RLWECt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe(module, self, lhs, scratch); - } -} - -impl RLWECt> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } - - pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn encrypt_rlwe_sk( - module: &Module, - ct: &mut RLWECt, - pt: Option<(&RLWEPt

, usize)>, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - // c1 = a - ct.data.fill_uniform(log_base2k, 1, size, source_xa); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some((pt, col)) = pt { - match col { - 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), - 1 => { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - module.vec_znx_add_inplace(ct, 1, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); - } - _ => panic!("invalid target column: {}", col), - } - } else { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); -} - -pub fn decrypt_rlwe( - module: &Module, - pt: &mut RLWEPt

, - ct: &RLWECt, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.log_base2k(); - pt.log_k = pt.log_k().min(ct.log_k()); -} - -impl RLWECt { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: Option<&RLWEPt

>, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - if let Some(pt) = pt { - encrypt_rlwe_sk( - module, - self, - Some((pt, 0)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch, - ) - } else { - encrypt_rlwe_sk::( - module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut RLWEPt

, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe(module, pt, self, sk_dft, scratch); - } - - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - { - encrypt_rlwe_pk( - module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch, - ) - } -} - -pub(crate) fn encrypt_rlwe_pk( - module: &Module, - ct: &mut RLWECt, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, -{ - #[cfg(debug_assertions)] - { - assert_eq!(ct.log_base2k(), pk.log_base2k()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.log_base2k(), pk.log_base2k()); - assert_eq!(pt.n(), module.n()); - } - } - - let log_base2k: usize = pk.log_base2k(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); -} - -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWEPt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize 
{ - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl RLWECtDft -where - RLWECtDft: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - RLWECt: VecZnxToMut, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - - module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); - } -} - -pub(crate) fn encrypt_zero_rlwe_dft_sk( - module: &Module, - ct: &mut RLWECtDft, - sk: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), - _ => {} - } - assert_eq!(ct.cols(), 2); - } - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = 
scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = ct[1] * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - module.vec_znx_negate_inplace(&mut tmp_znx, 0); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); -} - -impl RLWECtDft, FFT64> { - pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn decrypt_rlwe_dft( - module: &Module, - pt: &mut RLWEPt

, - ct: &RLWECtDft, - sk: &SecretKeyDft, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - { - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.log_base2k(); - pt.log_k = pt.log_k().min(ct.log_k()); -} - -impl RLWECtDft { - pub(crate) fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_zero_rlwe_dft_sk( - module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut RLWEPt

, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); - } -} - -impl ProdInplaceScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RLWECtDft -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft_inplace(module, self, scratch); - } -} - -impl Product for RLWECtDft -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RLWECtDft; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft(module, self, 
lhs, scratch); - } -} diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 81c1023..9d9a077 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -2,11 +2,15 @@ use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, + elem::{GetRow, Infos}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::rgsw::noise_rgsw_product, }; @@ -20,8 +24,8 @@ fn encrypt_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_ct, rows); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -31,13 +35,14 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let 
mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -51,7 +56,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); (0..ct.rows()).for_each(|row_i| { ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); @@ -60,12 +65,10 @@ fn encrypt_sk() { let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); }); - - module.free(); } #[test] -fn from_prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -74,18 +77,18 @@ fn from_prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::prod_by_grlwe_scratch_space( + 
GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GLWEKeySwitchKey::keyswitch_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -96,19 +99,19 @@ fn from_prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); let mut sk2: SecretKey> = SecretKey::new(&module); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -136,10 +139,11 @@ fn from_prod_by_grlwe() { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); @@ -166,12 +170,10 @@ fn from_prod_by_grlwe() { noise_want ); }); - - module.free(); } 
#[test] -fn prod_by_grlwe() { +fn keyswitch_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -180,35 +182,35 @@ fn prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWEKeySwitchKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); let mut sk2: SecretKey> = SecretKey::new(&module); 
sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -236,12 +238,13 @@ fn prod_by_grlwe() { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.prod_by_grlwe_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + let ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = ct_grlwe_s0s1; - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); @@ -268,12 +271,10 @@ fn prod_by_grlwe() { noise_want ); }); - - module.free(); } #[test] -fn from_prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -282,9 +283,9 @@ fn from_prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_in: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_out: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> 
= GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -294,15 +295,15 @@ fn from_prod_by_rgsw() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GRLWECt::prod_by_rgsw_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GLWEKeySwitchKey::external_product_scratch_space( &module, ct_grlwe_out.size(), ct_grlwe_in.size(), ct_rgsw.size(), ) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), ); let k: usize = 1; @@ -314,7 +315,7 @@ fn from_prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -341,10 +342,11 @@ fn from_prod_by_rgsw() { ); // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -382,12 +384,10 @@ fn 
from_prod_by_rgsw() { noise_want ); }); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -396,8 +396,8 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -407,10 +407,10 @@ fn prod_by_rgsw() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) - | GRLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) + | GLWEKeySwitchKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), ); let k: usize = 1; @@ -422,7 +422,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -449,10 +449,11 @@ fn prod_by_rgsw() { ); // GRLWE_(m) (x) RGSW_(X^k) = 
GRLWE_(m * X^k) - ct_grlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -490,8 +491,6 @@ fn prod_by_rgsw() { noise_want ); }); - - module.free(); } pub(crate) fn noise_grlwe_rlwe_product( diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 50cd356..820b671 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -5,16 +5,20 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, + elem::{GetRow, Infos}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::grlwe::noise_grlwe_rlwe_product, }; #[test] -fn encrypt_rgsw_sk() { +fn encrypt_sk() { let module: Module = Module::::new(2048); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -23,9 +27,9 @@ fn encrypt_rgsw_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let 
mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -35,13 +39,14 @@ fn encrypt_rgsw_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -55,11 +60,11 @@ fn encrypt_rgsw_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - (0..ct.cols()).for_each(|col_j| { + (0..ct.rank()).for_each(|col_j| { (0..ct.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); @@ -82,12 +87,10 @@ fn encrypt_rgsw_sk() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -98,9 +101,9 @@ fn 
from_prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_in, rows); - let mut ct_rgsw_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows); + let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -111,10 +114,10 @@ fn from_prod_by_grlwe() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) - | RGSWCt::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | GGSWCiphertext::keyswitch_scratch_space( &module, ct_rgsw_out.size(), ct_rgsw_in.size(), @@ -125,13 +128,13 @@ fn from_prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: 
SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -156,15 +159,15 @@ fn from_prod_by_grlwe() { scratch.borrow(), ); - ct_rgsw_out.prod_by_grlwe(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); + ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_out); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - (0..ct_rgsw_out.cols()).for_each(|col_j| { + (0..ct_rgsw_out.rank()).for_each(|col_j| { (0..ct_rgsw_out.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); @@ -203,12 +206,10 @@ fn from_prod_by_grlwe() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_grlwe_inplace() { +fn keyswitch_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -218,8 +219,8 @@ fn from_prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, 
log_base2k, log_k_rgsw, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -230,22 +231,22 @@ fn from_prod_by_grlwe_inplace() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RGSWCt::prod_by_grlwe_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -270,15 +271,15 @@ fn from_prod_by_grlwe_inplace() { scratch.borrow(), ); - ct_rgsw.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw); + let mut pt: GLWEPlaintext> = 
GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - (0..ct_rgsw.cols()).for_each(|col_j| { + (0..ct_rgsw.rank()).for_each(|col_j| { (0..ct_rgsw.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); @@ -317,12 +318,10 @@ fn from_prod_by_grlwe_inplace() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_rgsw_rhs: usize = 60; @@ -333,9 +332,9 @@ fn from_prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); - let mut ct_rgsw_lhs_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); + let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -351,10 +350,10 @@ fn from_prod_by_rgsw() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | 
RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) - | RGSWCt::prod_by_rgsw_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | GGSWCiphertext::external_product_scratch_space( &module, ct_rgsw_lhs_out.size(), ct_rgsw_lhs_in.size(), @@ -365,7 +364,7 @@ fn from_prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -390,17 +389,18 @@ fn from_prod_by_rgsw() { scratch.borrow(), ); - ct_rgsw_lhs_out.prod_by_rgsw(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); + ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs_out); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - (0..ct_rgsw_lhs_out.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { 
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); @@ -448,12 +448,10 @@ fn from_prod_by_rgsw() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_rgsw_rhs: usize = 60; @@ -463,8 +461,8 @@ fn from_prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -480,16 +478,16 @@ fn from_prod_by_rgsw_inplace() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) - | RGSWCt::prod_by_rgsw_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = 
SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -514,17 +512,17 @@ fn from_prod_by_rgsw_inplace() { scratch.borrow(), ); - ct_rgsw_lhs.prod_by_rgsw_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); + ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - (0..ct_rgsw_lhs.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs.rank()).for_each(|col_j| { (0..ct_rgsw_lhs.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); @@ -572,8 +570,6 @@ fn from_prod_by_rgsw_inplace() { pt_want.data.zero(); }); }); - - module.free(); } pub(crate) fn noise_rgsw_product( diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index a2fabb9..6958925 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -6,11 +6,16 @@ use itertools::izip; use sampling::source::Source; use crate::{ - elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{PublicKey, SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, 
ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::{PublicKey, SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, }; @@ -24,21 +29,21 @@ fn encrypt_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -52,7 +57,7 @@ fn encrypt_sk() { ct.encrypt_sk( &module, - Some(&pt), + &pt, &sk_dft, &mut source_xa, &mut source_xe, @@ -81,8 +86,6 @@ fn encrypt_sk() { b_scaled ) }); - - module.free(); } #[test] @@ -94,7 +97,7 @@ fn encrypt_zero_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = 
GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -102,14 +105,14 @@ fn encrypt_zero_sk() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) + | GLWECiphertextFourier::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -124,7 +127,6 @@ fn encrypt_zero_sk() { ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); } #[test] @@ -137,8 +139,8 @@ fn encrypt_pk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -147,7 +149,7 @@ fn encrypt_pk() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = 
SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); @@ -161,9 +163,9 @@ fn encrypt_pk() { ); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) - | RLWECt::decrypt_scratch_space(&module, ct.size()) - | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; @@ -178,7 +180,7 @@ fn encrypt_pk() { ct.encrypt_pk( &module, - Some(&pt_want), + &pt_want, &pk, &mut source_xu, &mut source_xe, @@ -187,19 +189,17 @@ fn encrypt_pk() { scratch.borrow(), ); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); } #[test] -fn prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -210,11 +210,11 @@ fn prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, 
rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -226,10 +226,10 @@ fn prod_by_grlwe() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::keyswitch_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -240,13 +240,13 @@ fn prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -262,7 +262,7 @@ fn prod_by_grlwe() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -271,7 +271,7 @@ fn prod_by_grlwe() { scratch.borrow(), ); 
- ct_rlwe_out.prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + ct_rlwe_out.keyswitch(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -296,12 +296,10 @@ fn prod_by_grlwe() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_grlwe_inplace() { +fn keyswich_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -311,10 +309,10 @@ fn prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -326,22 +324,22 @@ fn prod_by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | 
GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -357,7 +355,7 @@ fn prod_by_grlwe_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -366,7 +364,7 @@ fn prod_by_grlwe_inplace() { scratch.borrow(), ); - ct_rlwe.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -391,12 +389,10 @@ fn prod_by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -407,12 +403,12 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = 
GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -430,10 +426,10 @@ fn prod_by_rgsw() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_grlwe_scratch_space( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -444,7 +440,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -460,7 +456,7 @@ fn prod_by_rgsw() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -469,7 +465,7 @@ fn prod_by_rgsw() { scratch.borrow(), ); - ct_rlwe_out.prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_out.external_product(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -505,12 
+501,10 @@ fn prod_by_rgsw() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -521,11 +515,11 @@ fn prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -543,16 +537,16 @@ fn prod_by_rgsw_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); 
sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -568,7 +562,7 @@ fn prod_by_rgsw_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -577,7 +571,7 @@ fn prod_by_rgsw_inplace() { scratch.borrow(), ); - ct_rlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -613,6 +607,4 @@ fn prod_by_rgsw_inplace() { noise_have, noise_want ); - - module.free(); } diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs index fe71a09..06359b1 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/rlwe_dft.rs @@ -1,16 +1,21 @@ use crate::{ - elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; #[test] -fn by_grlwe_inplace() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -21,13 
+26,15 @@ fn by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -39,10 +46,10 @@ fn by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECtDft::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, 
ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertextFourier::keyswitch_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -53,13 +60,13 @@ fn by_grlwe_inplace() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -75,7 +82,7 @@ fn by_grlwe_inplace() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -85,7 +92,7 @@ fn by_grlwe_inplace() { ); ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -111,12 +118,10 @@ fn by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_grlwe_inplace() { +fn keyswich_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -126,11 +131,11 @@ fn prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = 
RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -142,22 +147,22 @@ fn prod_by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECtDft::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + 
let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -173,7 +178,7 @@ fn prod_by_grlwe_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -183,7 +188,7 @@ fn prod_by_grlwe_inplace() { ); ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -209,12 +214,10 @@ fn prod_by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -225,14 +228,16 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); let mut 
pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -250,10 +255,10 @@ fn prod_by_rgsw() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_rgsw_scratch_space( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -264,7 +269,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -280,7 +285,7 @@ fn prod_by_rgsw() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -290,7 +295,7 @@ fn prod_by_rgsw() { ); ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut 
pt_have, &sk_dft, scratch.borrow()); @@ -327,12 +332,10 @@ fn prod_by_rgsw() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -343,12 +346,12 @@ fn prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -366,16 +369,16 @@ fn prod_by_rgsw_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | 
GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -391,7 +394,7 @@ fn prod_by_rgsw_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -401,7 +404,7 @@ fn prod_by_rgsw_inplace() { ); ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -438,6 +441,4 @@ fn prod_by_rgsw_inplace() { noise_have, noise_want ); - - module.free(); } diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs new file mode 100644 index 0000000..7920de9 --- /dev/null +++ b/core/src/vec_glwe_product.rs @@ -0,0 +1,197 @@ +use base2k::{ + FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, +}; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + glwe::{GLWECiphertext, GLWECiphertextFourier}, +}; + +pub(crate) trait VecGLWEProductScratchSpace { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + + fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs) + } + + fn prod_with_glwe_dft_scratch_space(module: &Module, res_size: 
usize, lhs: usize, rhs: usize) -> usize { + (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, lhs) + + module.bytes_of_vec_znx(2, res_size) + } + + fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + fn prod_with_vec_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs) + + module.bytes_of_vec_znx_dft(2, lhs) + + module.bytes_of_vec_znx_dft(2, res_size) + } + + fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs) + module.bytes_of_vec_znx_dft(2, res_size) + } +} + +pub(crate) trait VecGLWEProduct: Infos { + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn prod_with_glwe_inplace(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut GLWECiphertext = res as *mut GLWECiphertext; // This is ok because [Self::mul_rlwe] only updates res at the end. 
+ self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + fn prod_with_glwe_fourier( + &self, + module: &Module, + res: &mut GLWECiphertextFourier, + a: &GLWECiphertextFourier, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: a_data, + log_base2k: a.basek(), + log_k: a.k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn prod_with_glwe_fourier_inplace( + &self, + module: &Module, + res: &mut GLWECiphertextFourier, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn 
prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) + where + LHS: GetRow + Infos, + RES: SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.basek(), + log_k: a.k(), + }; + + let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..res.rows()).for_each(|row_i| { + (0..res.rank()).for_each(|col_j| { + a.get_row(module, row_i, col_j, &mut tmp_a_row); + self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); + res.set_row(module, row_i, col_j, &tmp_res_row); + }); + }); + + tmp_res_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + (0..self.rank()).for_each(|col_j| { + res.set_row(module, row_i, col_j, &tmp_res_row); + }); + }); + } + + fn prod_with_vec_glwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) + where + RES: GetRow + SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + (0..res.rows()).for_each(|row_i| { + (0..res.rank()).for_each(|col_j| { + res.get_row(module, row_i, col_j, &mut tmp_row); + self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); + }); + } +} From dee889dc0cbd59bf0af784efa66c726ac27816f2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 13 May 2025 17:21:41 +0200 Subject: [PATCH 62/87] 
working on adding rank to glwe (all test passing) --- base2k/src/mat_znx_dft.rs | 464 +++++----- core/benches/external_product_glwe_fft64.rs | 2 +- core/src/elem.rs | 6 +- core/src/encryption.rs | 105 --- core/src/external_product.rs | 19 - core/src/gglwe_ciphertext.rs | 253 ++++++ core/src/ggsw.rs | 324 ------- core/src/ggsw_ciphertext.rs | 316 +++++++ core/src/glwe.rs | 845 ------------------ core/src/glwe_ciphertext.rs | 460 ++++++++++ core/src/glwe_ciphertext_fourier.rs | 261 ++++++ core/src/glwe_plaintext.rs | 53 ++ core/src/keys.rs | 83 +- core/src/keyswitch.rs | 20 - core/src/keyswitch_key.rs | 344 +++---- core/src/lib.rs | 10 +- core/src/test_fft64/{grlwe.rs => gglwe.rs} | 129 +-- core/src/test_fft64/{rgsw.rs => ggsw.rs} | 104 ++- core/src/test_fft64/{rlwe.rs => glwe.rs} | 115 +-- .../{rlwe_dft.rs => glwe_fourier.rs} | 96 +- core/src/test_fft64/mod.rs | 8 +- core/src/vec_glwe_product.rs | 33 +- 22 files changed, 2020 insertions(+), 2030 deletions(-) delete mode 100644 core/src/encryption.rs delete mode 100644 core/src/external_product.rs create mode 100644 core/src/gglwe_ciphertext.rs delete mode 100644 core/src/ggsw.rs create mode 100644 core/src/ggsw_ciphertext.rs delete mode 100644 core/src/glwe.rs create mode 100644 core/src/glwe_ciphertext.rs create mode 100644 core/src/glwe_ciphertext_fourier.rs create mode 100644 core/src/glwe_plaintext.rs delete mode 100644 core/src/keyswitch.rs rename core/src/test_fft64/{grlwe.rs => gglwe.rs} (79%) rename core/src/test_fft64/{rgsw.rs => ggsw.rs} (87%) rename core/src/test_fft64/{rlwe.rs => glwe.rs} (84%) rename core/src/test_fft64/{rlwe_dft.rs => glwe_fourier.rs} (84%) diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index c34115d..209c696 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,232 +1,232 @@ -use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -use std::marker::PhantomData; - -/// 
Vector Matrix Product Prepared Matrix: a vector of [VecZnx], -/// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. -/// -/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. -/// See the trait [MatZnxDftOps] for additional information. -pub struct MatZnxDft { - data: D, - n: usize, - size: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for MatZnxDft { - fn cols(&self) -> usize { - self.cols_in - } - - fn rows(&self) -> usize { - self.rows - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for MatZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols_out() - } -} - -impl DataView for MatZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for MatZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for MatZnxDft { - type Scalar = f64; -} - -impl MatZnxDft { - pub(crate) fn cols_in(&self) -> usize { - self.cols_in - } - - pub(crate) fn cols_out(&self) -> usize { - self.cols_out - } -} - -impl>, B: Backend> MatZnxDft { - pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - unsafe { - crate::ffi::vmp::bytes_of_vmp_pmat( - module.ptr, - (rows * cols_in) as u64, - (size * cols_out) as u64, - ) as usize - } - } - - pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) 
-> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } -} - -impl> MatZnxDft { - /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. - /// - /// # Arguments - /// - /// * `row`: row index (i). - /// * `col`: col index (j). - #[allow(dead_code)] - fn at(&self, row: usize, col: usize) -> Vec { - let n: usize = self.n(); - - let mut res: Vec = alloc_aligned(n); - - if n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); - } else { - (0..n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); - }); - } - - res - } - - #[allow(dead_code)] - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { - let nrows: usize = self.rows(); - let nsize: usize = self.size(); - if col == (nsize - 1) && (nsize & 1 == 1) { - &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] - } else { - &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
- } - } -} - -pub type MatZnxDftOwned = MatZnxDft, B>; - -pub trait MatZnxDftToRef { - fn to_ref(&self) -> MatZnxDft<&[u8], B>; -} - -pub trait MatZnxDftToMut { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; -} - -impl MatZnxDftToMut for MatZnxDft, B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data.as_mut_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft, B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&[u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::marker::PhantomData; + +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. 
+/// +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. +pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() * self.cols_out() + } +} + +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { + type Scalar = f64; +} + +impl MatZnxDft { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + crate::ffi::vmp::bytes_of_vmp_pmat( + module.ptr, + (rows * cols_in) as u64, + (size * cols_out) as u64, + ) as usize + } + } + + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: 
PhantomData, + } + } +} + +impl> MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. + /// + /// # Arguments + /// + /// * `row`: row index (i). + /// * `col`: col index (j). + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); + + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); + } else { + (0..n >> 3).for_each(|blk| { + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); + }); + } + + res + } + + #[allow(dead_code)] + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { + let nrows: usize = self.rows(); + let nsize: usize = self.size(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] + } else { + &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
+ } + } +} + +pub type MatZnxDftOwned = MatZnxDft, B>; + +pub trait MatZnxDftToRef { + fn to_ref(&self) -> MatZnxDft<&[u8], B>; +} + +pub trait MatZnxDftToMut { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; +} + +impl MatZnxDftToMut for MatZnxDft, B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft, B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&[u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index 4462fab..435a25f 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -6,7 +6,7 @@ use rlwe::{ external_product::{ ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - ggsw::GGSWCiphertext, + 
ggsw_ciphertext::GGSWCiphertext, glwe::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, }; diff --git a/core/src/elem.rs b/core/src/elem.rs index bf5ca1e..4562137 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{glwe::GLWECiphertextFourier, utils::derive_size}; +use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; pub trait Infos { type Inner: ZnxInfos; @@ -23,7 +23,7 @@ pub trait Infos { } /// Returns the number of polynomials in each row. - fn rank(&self) -> usize { + fn cols(&self) -> usize { self.inner().cols() } @@ -36,7 +36,7 @@ pub trait Infos { /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { - self.rows() * self.rank() * self.size() + self.rows() * self.cols() * self.size() } /// Returns the base 2 logarithm of the ciphertext base. diff --git a/core/src/encryption.rs b/core/src/encryption.rs deleted file mode 100644 index 915834c..0000000 --- a/core/src/encryption.rs +++ /dev/null @@ -1,105 +0,0 @@ -use base2k::{Backend, Module, Scratch}; -use sampling::source::Source; - -pub trait EncryptSkScratchSpace { - fn encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptSk { - type Ciphertext; - type Plaintext; - type SecretKey; - - fn encrypt_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait EncryptZeroSkScratchSpace { - fn encrypt_zero_sk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptZeroSk { - type Ciphertext; - type SecretKey; - - fn encrypt_zero_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut 
Scratch, - ); -} - -pub trait EncryptPkScratchSpace { - fn encrypt_pk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptPk { - type Ciphertext; - type Plaintext; - type PublicKey; - - fn encrypt_pk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - pk: &Self::PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait EncryptZeroPkScratchSpace { - fn encrypt_zero_pk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptZeroPk { - type Ciphertext; - type PublicKey; - - fn encrypt_zero_pk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pk: &Self::PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait Decrypt { - type Plaintext; - type Ciphertext; - type SecretKey; - - fn decrypt( - &self, - module: &Module, - pt: &mut Self::Plaintext, - ct: &Self::Ciphertext, - sk: &Self::SecretKey, - scratch: &mut Scratch, - ); -} diff --git a/core/src/external_product.rs b/core/src/external_product.rs deleted file mode 100644 index e8d0a7e..0000000 --- a/core/src/external_product.rs +++ /dev/null @@ -1,19 +0,0 @@ -use base2k::{FFT64, Module, Scratch}; - -pub trait ExternalProductScratchSpace { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait ExternalProduct { - type Lhs; - type Rhs; - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); -} -pub trait ExternalProductInplaceScratchSpace { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; -} - -pub trait ExternalProductInplace { - type Rhs; - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); -} diff --git a/core/src/gglwe_ciphertext.rs 
b/core/src/gglwe_ciphertext.rs new file mode 100644 index 0000000..9d7c45a --- /dev/null +++ b/core/src/gglwe_ciphertext.rs @@ -0,0 +1,253 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GGLWECiphertext { + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, +} + +impl GGLWECiphertext, B> { + pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(base2k, k)), + basek: base2k, + k, + } + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGLWECiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + 
GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn encrypt_pk_scratch_space(_module: &Module, _rank: usize, _pk_size: usize) -> usize { + unimplemented!() + } +} + +impl GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut + ZnxInfos, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(pt.n(), module.n()); + } + + let rows: usize = self.rows(); + let size: usize = self.size(); + let basek: usize = self.basek(); + let k: usize = self.k(); + + let cols: usize = self.rank() + 1; + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size); + let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek, + k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek, + k, + }; + + let mut vec_znx_ct_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier { + data: tmp_znx_dft_ct, + basek, + k, + }; + + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + &vec_znx_pt, + sk_dft, + source_xa, + source_xe, + sigma, + bound, 
+ scratch_3, + ); + + vec_znx_pt.data.zero(); // zeroes for next iteration + + // Switch vec_znx_ct into DFT domain + vec_znx_ct.dft(module, &mut vec_znx_ct_dft); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft); + }); + } +} + +impl GetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } +} + +impl VecGLWEProduct for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); 
// Todo optimise + + { + let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); + module.vec_znx_dft(&mut a1_dft, 0, a, 1); + module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } +} diff --git a/core/src/ggsw.rs b/core/src/ggsw.rs deleted file mode 100644 index 79b12a5..0000000 --- a/core/src/ggsw.rs +++ /dev/null @@ -1,324 +0,0 @@ -use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos, SetRow}, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, encrypt_glwe_sk}, - keys::SecretKeyFourier, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, -}; - -pub struct GGSWCiphertext { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GGSWCiphertext, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: 
log_k, - } - } -} - -impl Infos for GGSWCiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl GGSWCiphertext, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, size) - + module.bytes_of_vec_znx(2, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(2, size) - } -} - -pub fn encrypt_rgsw_sk( - module: &Module, - ct: &mut GGSWCiphertext, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let size: usize = ct.size(); - let log_base2k: usize = ct.basek(); - - let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - - let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { - data: tmp_znx_pt, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_znx_ct, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - (0..ct.rows()).for_each(|row_j| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - - (0..ct.rank()).for_each(|col_i| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - encrypt_glwe_sk( - module, - &mut vec_znx_ct, - Some((&vec_znx_pt, col_i)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scrach_2, - ); - - // Switch vec_znx_ct into DFT domain - { - let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size); - module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); - module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); - module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct); - } - }); - - vec_znx_pt.data.zero(); // zeroes for next iteration - }); -} - -impl GGSWCiphertext { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_rgsw_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } -} - -impl GetRow for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef, -{ - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, col_j); - } -} - -impl SetRow for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut, -{ - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, col_j, a); - } -} - -impl KeySwitchScratchSpace for GGSWCiphertext, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, - ) - } -} - -impl KeySwitch for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GGSWCiphertext; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GGSWCiphertext, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, rhs, scratch); - } -} - -impl ExternalProductScratchSpace for GGSWCiphertext, 
FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, - ) - } -} - -impl ExternalProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GGSWCiphertext; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GGSWCiphertext, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe_inplace(module, self, scratch); - } -} - -impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) - | module.vec_znx_big_normalize_tmp_bytes()) - } -} - -impl VecGLWEProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - 
assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - - { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); - module.vec_znx_dft(&mut a_dft, 0, a, 0); - module.vec_znx_dft(&mut a_dft, 1, a, 1); - module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); - } -} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs new file mode 100644 index 0000000..9d42df8 --- /dev/null +++ b/core/src/ggsw_ciphertext.rs @@ -0,0 +1,316 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GGSWCiphertext { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GGSWCiphertext, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self { + Self { + data: 
module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GGSWCiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl GGSWCiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGSWCiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GGSWCiphertext +where + 
MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let size: usize = self.size(); + let log_base2k: usize = self.basek(); + let k: usize = self.k(); + let cols: usize = self.rank() + 1; + + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, cols, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek: log_base2k, + k: k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek: log_base2k, + k, + }; + + (0..self.rows()).for_each(|row_j| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); + + (0..cols).for_each(|col_i| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + vec_znx_ct.encrypt_sk_private( + module, + Some((&vec_znx_pt, col_i)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scrach_2, + ); + + // Switch vec_znx_ct into DFT domain + { + let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size); + + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); + }); + + module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); + } + }); + + vec_znx_pt.data.zero(); // zeroes for next iteration + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GLWESwitchingKey, + scratch: 
&mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_vec_glwe(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_vec_glwe_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl GetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut GLWECiphertextFourier, + ) where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } +} + +impl VecGLWEProduct for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: 
&mut Scratch, + ) where + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); + module.vec_znx_dft(&mut a_dft, 0, a, 0); + module.vec_znx_dft(&mut a_dft, 1, a, 1); + module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } +} diff --git a/core/src/glwe.rs b/core/src/glwe.rs deleted file mode 100644 index e50582d..0000000 --- a/core/src/glwe.rs +++ /dev/null @@ -1,845 +0,0 @@ -use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; -use sampling::source::Source; - -use crate::{ - elem::Infos, - encryption::{EncryptSk, EncryptSkScratchSpace, EncryptZeroSkScratchSpace}, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - keys::{PublicKey, SecretDistribution, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, 
VecGLWEProductScratchSpace}, -}; - -pub struct GLWECiphertext { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GLWECiphertext> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for GLWECiphertext -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - #[allow(dead_code)] - pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToMut + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), 2); - assert_eq!(res.rank(), 2); - assert_eq!(self.basek(), res.basek()) - } - - module.vec_znx_dft(res, 0, self, 0); - module.vec_znx_dft(res, 1, self, 1); - } -} - -impl KeySwitchScratchSpace for GLWECiphertext> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl KeySwitch for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertext; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWECiphertext> { - fn keyswitch_inplace_scratch_space(module: 
&Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_inplace(module, self, scratch); - } -} - -impl ExternalProductScratchSpace for GLWECiphertext> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ExternalProduct for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertext; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWECiphertext> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_inplace(module, self, scratch); - } -} - -impl GLWECiphertext> { - pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + 
module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -impl EncryptSkScratchSpace for GLWECiphertext> { - fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -impl EncryptSk for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - type Ciphertext = GLWECiphertext; - type Plaintext = GLWEPlaintext; - type SecretKey = SecretKeyFourier; - - fn encrypt_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) { - encrypt_glwe_sk( - module, - ct, - Some((pt, 0)), - sk, - source_xa, - source_xe, - sigma, - bound, - scratch, - ); - } -} - -pub(crate) fn encrypt_glwe_sk( - module: &Module, - ct: &mut GLWECiphertext, - pt: Option<(&GLWEPlaintext, usize)>, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.basek(); - let log_k: usize = ct.k(); - let size: usize = ct.size(); - - // c1 = a - ct.data.fill_uniform(log_base2k, 1, size, source_xa); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut 
c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some((pt, col)) = pt { - match col { - 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), - 1 => { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - module.vec_znx_add_inplace(ct, 1, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); - } - _ => panic!("invalid target column: {}", col), - } - } else { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); -} - -pub fn decrypt_glwe( - module: &Module, - pt: &mut GLWEPlaintext

, - ct: &GLWECiphertext, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.basek(); - pt.log_k = pt.k().min(ct.k()); -} - -impl GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, -{ - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_sk( - module, - self, - Some((pt, 0)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch, - ) - } - - pub fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_sk::( - module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - { - encrypt_glwe_pk( - module, - self, - Some(pt), - pk, - source_xu, 
- source_xe, - sigma, - bound, - scratch, - ) - } - - pub fn encrypt_zero_pk( - &mut self, - module: &Module, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToRef, - { - encrypt_glwe_pk::( - module, self, None, pk, source_xu, source_xe, sigma, bound, scratch, - ) - } -} - -impl GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_glwe(module, pt, self, sk_dft, scratch); - } -} - -pub(crate) fn encrypt_glwe_pk( - module: &Module, - ct: &mut GLWECiphertext, - pt: Option<&GLWEPlaintext

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, -{ - #[cfg(debug_assertions)] - { - assert_eq!(ct.basek(), pk.basek()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.basek(), pk.basek()); - assert_eq!(pt.n(), module.n()); - } - } - - let log_base2k: usize = pk.basek(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. 
when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); -} - -pub struct GLWEPlaintext { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for GLWEPlaintext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for GLWEPlaintext -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for GLWEPlaintext -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl GLWEPlaintext> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct GLWECiphertextFourier { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GLWECiphertextFourier, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GLWECiphertextFourier { - type Inner = VecZnxDft; - - fn inner(&self) -> 
&Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl GLWECiphertextFourier -where - GLWECiphertextFourier: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) - where - GLWECiphertext: VecZnxToMut, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), 2); - assert_eq!(res.rank(), 2); - assert_eq!(self.basek(), res.basek()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - - module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); - module.vec_znx_big_normalize(self.basek(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.basek(), res, 1, &res_big, 1, scratch1); - } -} - -pub(crate) fn encrypt_zero_glwe_dft_sk( - module: &Module, - ct: &mut GLWECiphertextFourier, - sk: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.basek(); - let log_k: usize = ct.k(); - let size: usize = ct.size(); - - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), - 
_ => {} - } - assert_eq!(ct.rank(), 2); - } - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = ct[1] * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - module.vec_znx_negate_inplace(&mut tmp_znx, 0); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); -} - -impl GLWECiphertextFourier, FFT64> { - pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn decrypt_rlwe_dft( - module: &Module, - pt: &mut GLWEPlaintext

, - ct: &GLWECiphertextFourier, - sk: &SecretKeyFourier, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - { - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.basek(); - pt.log_k = pt.k().min(ct.k()); -} - -impl GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, -{ - pub(crate) fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_zero_glwe_dft_sk( - module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext

, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); - } -} - -impl KeySwitchScratchSpace for GLWECiphertextFourier, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl KeySwitch for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertextFourier; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWECiphertextFourier, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); - } -} - -impl ExternalProductScratchSpace for GLWECiphertextFourier, FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ExternalProduct for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertextFourier; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: 
&Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWECiphertextFourier, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); - } -} diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs new file mode 100644 index 0000000..ed8a39e --- /dev/null +++ b/core/src/glwe_ciphertext.rs @@ -0,0 +1,460 @@ +use base2k::{ + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx(rank + 1, 
derive_size(basek, k)), + basek, + k, + } + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxToMut for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_dft(res, i, self, i); + }) + } +} + +impl GLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, _rank: usize, ct_size: usize) -> usize { + module.vec_znx_big_normalize_tmp_bytes() + + module.bytes_of_vec_znx_dft(1, ct_size) + + module.bytes_of_vec_znx_big(1, ct_size) + } + pub fn encrypt_pk_scratch_space(module: &Module, _rank: usize, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, ct_size)) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) 
+ } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private( + module, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private( + module, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ); + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private( + module, + Some((pt, 0)), + pk, + source_xu, + source_xe, + sigma, + bound, + scratch, + ); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: 
&mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private( + module, None, pk, source_xu, source_xe, sigma, bound, scratch, + ); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_inplace(module, self, scratch); + } + + pub(crate) fn encrypt_sk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(self.n(), module.n()); + if let Some((pt, col)) = pt { + assert_eq!(pt.n(), module.n()); + assert!(col < self.rank() + 1); + } + } + + let log_base2k: usize = self.basek(); + let log_k: usize = self.k(); + let size: usize = self.size(); + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + c0_big.zero(); + + { + // c[i] = uniform + // c[0] -= c[i] * 
s[i], + (1..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + + // c[i] = uniform + self.data.fill_uniform(log_base2k, i, size, source_xa); + + // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + module.vec_znx_big_normalize(log_base2k, self, 0, &ci_big, 0, scratch_2); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0); + + // c[i] += m if col = i + if let Some((pt, col)) = pt { + if i == col { + module.vec_znx_add_inplace(self, i, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, self, i, scratch_2); + } + } + }); + } + + // c[0] += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt { + if col == 0 { + module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0); + } + } + + // c[0] = norm(c[0]) + module.vec_znx_normalize(log_base2k, self, 0, &c0_big, 0, scratch_1); + } + + pub(crate) fn encrypt_pk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.n(), module.n()); + assert_eq!(pk.n(), module.n()); + assert_eq!(self.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.basek(); + let size_pk: usize = pk.size(); + let cols: usize = self.rank() + 1; + + // Generates u according to the underlying secret distribution. 
+ let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ + Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + // ct[i] = pk[i] * u + ei (+ m if col = i) + (0..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i); + + // ci_big = u * p[i] + let mut ci_big = module.vec_znx_idft_consume(ci_dft); + + // ci_big = u * pk[i] + e + ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt { + if col == i { + module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0); + } + } + + // ct[i] = norm(ci_big) + module.vec_znx_big_normalize(log_base2k, self, i, &ci_big, 0, scratch_2); + }); + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + ScalarZnxDft: ScalarZnxDftToRef, + { + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut c0_dft, 0, self, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a 
* s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } +} diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs new file mode 100644 index 0000000..bcc7648 --- /dev/null +++ b/core/src/glwe_ciphertext_fourier.rs @@ -0,0 +1,261 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, + VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertextFourier { + pub data: VecZnxDft, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertextFourier, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)), + basek: log_base2k, + k: log_k, + } + } +} + +impl Infos for GLWECiphertextFourier { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertextFourier { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl 
VecZnxDftToRef for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWECiphertextFourier, FFT64> { + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, ct_size: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, rank, ct_size) + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, ct_size) + | (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + 
sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); + let mut ct_idft = GLWECiphertext { + data: vec_znx_tmp, + basek: self.basek, + k: self.k, + }; + ct_idft.encrypt_zero_sk( + module, sk_dft, source_xa, source_xe, sigma, bound, scratch_1, + ); + + ct_idft.dft(module, self); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols = self.rank() + 1; + + let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, 
self.size()); // TODO optimize size when pt << ct + pt_big.zero(); + + { + (1..cols).for_each(|i| { + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); + }); + } + + { + let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } + + pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + GLWECiphertext: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_idft(&mut res_big, 0, self, i, scratch1); + module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1); + }); + } +} diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs new file mode 100644 index 0000000..75088d1 --- /dev/null +++ b/core/src/glwe_plaintext.rs @@ -0,0 +1,53 @@ +use base2k::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; + +use crate::{elem::Infos, utils::derive_size}; + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl Infos for GLWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + 
fn k(&self) -> usize { + self.k + } +} + +impl VecZnxToMut for GLWEPlaintext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWEPlaintext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWEPlaintext> { + pub fn new(module: &Module, base2k: usize, k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(base2k, k)), + basek: base2k, + k, + } + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs index eaa569e..d57fa73 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -5,7 +5,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{elem::Infos, glwe::GLWECiphertextFourier}; +use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -21,25 +21,43 @@ pub struct SecretKey { } impl SecretKey> { - pub fn new(module: &Module) -> Self { + pub fn new(module: &Module, rank: usize) -> Self { Self { - data: module.new_scalar_znx(1), + data: module.new_scalar_znx(rank), dist: SecretDistribution::NONE, } } } +impl SecretKey { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + impl SecretKey where S: AsMut<[u8]> + AsRef<[u8]>, { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - self.data.fill_ternary_prob(0, prob, source); + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_prob(i, prob, source); + }); self.dist = SecretDistribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - self.data.fill_ternary_hw(0, hw, source); + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_hw(i, hw, source); + }); self.dist = SecretDistribution::TernaryFixed(hw); } @@ -72,10 +90,24 @@ pub struct SecretKeyFourier { pub dist: SecretDistribution, } +impl 
SecretKeyFourier { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + impl SecretKeyFourier, B> { - pub fn new(module: &Module) -> Self { + pub fn new(module: &Module, rank: usize) -> Self { Self { - data: module.new_scalar_znx_dft(1), + data: module.new_scalar_znx_dft(rank), dist: SecretDistribution::NONE, } } @@ -91,9 +123,15 @@ impl SecretKeyFourier, B> { SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), _ => {} } + + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank(), sk.rank()); } - module.svp_prepare(self, 0, sk, 0); + (0..self.rank()).for_each(|i| { + module.svp_prepare(self, i, sk, i); + }); self.dist = sk.dist; } } @@ -116,21 +154,21 @@ where } } -pub struct PublicKey { +pub struct GLWEPublicKey { pub data: GLWECiphertextFourier, pub dist: SecretDistribution, } -impl PublicKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { +impl GLWEPublicKey, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { Self { - data: GLWECiphertextFourier::new(module, log_base2k, log_k), + data: GLWECiphertextFourier::new(module, log_base2k, log_k, rank), dist: SecretDistribution::NONE, } } } -impl Infos for PublicKey { +impl Infos for GLWEPublicKey { type Inner = VecZnxDft; fn inner(&self) -> &Self::Inner { @@ -138,15 +176,21 @@ impl Infos for PublicKey { } fn basek(&self) -> usize { - self.data.log_base2k + self.data.basek } fn k(&self) -> usize { - self.data.log_k + self.data.k } } -impl VecZnxDftToMut for PublicKey +impl GLWEPublicKey { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWEPublicKey where VecZnxDft: VecZnxDftToMut, { @@ -155,7 +199,7 @@ where } } -impl VecZnxDftToRef for PublicKey +impl VecZnxDftToRef for GLWEPublicKey where VecZnxDft: VecZnxDftToRef, { @@ -164,7 +208,7 @@ 
where } } -impl PublicKey { +impl GLWEPublicKey { pub fn generate( &mut self, module: &Module, @@ -186,8 +230,9 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_zero_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( module, + self.rank(), self.size(), )); self.data.encrypt_zero_sk( diff --git a/core/src/keyswitch.rs b/core/src/keyswitch.rs deleted file mode 100644 index c77ccb4..0000000 --- a/core/src/keyswitch.rs +++ /dev/null @@ -1,20 +0,0 @@ -use base2k::{FFT64, Module, Scratch}; - -pub trait KeySwitchScratchSpace { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait KeySwitch { - type Lhs; - type Rhs; - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); -} - -pub trait KeySwitchInplaceScratchSpace { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; -} - -pub trait KeySwitchInplace { - type Rhs; - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); -} diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index cb4c248..33d2a45 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,170 +1,63 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, + ScalarZnxToRef, Scratch, 
VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, }; use sampling::source::Source; use crate::{ elem::{GetRow, Infos, SetRow}, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, keys::SecretKeyFourier, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - utils::derive_size, vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct GLWEKeySwitchKey { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} +pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); -impl GLWEKeySwitchKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } +impl GLWESwitchingKey, FFT64> { + pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + GLWESwitchingKey(GGLWECiphertext::new( + module, base2k, k, rows, rank_in, rank_out, + )) } } -impl Infos for GLWEKeySwitchKey { +impl Infos for GLWESwitchingKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { - &self.data + &self.0.inner() } fn basek(&self) -> usize { - self.log_base2k + self.0.basek() } fn k(&self) -> usize { - self.log_k + self.0.k() } } -impl MatZnxDftToMut for GLWEKeySwitchKey +impl MatZnxDftToMut for GLWESwitchingKey where - MatZnxDft: MatZnxDftToMut, + MatZnxDft: MatZnxDftToMut, { fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() + self.0.data.to_mut() } } -impl MatZnxDftToRef for GLWEKeySwitchKey +impl MatZnxDftToRef for GLWESwitchingKey where - MatZnxDft: 
MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() + self.0.data.to_ref() } } -impl GLWEKeySwitchKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, size) - + module.bytes_of_vec_znx(2, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(2, size) - } -} - -pub fn encrypt_glwe_key_switch_key_sk( - module: &Module, - ct: &mut GLWEKeySwitchKey, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let rows: usize = ct.rows(); - let size: usize = ct.size(); - let log_base2k: usize = ct.basek(); - - let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); - let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); - - let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { - data: tmp_znx_pt, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_znx_ct, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - vec_znx_ct.encrypt_sk( - module, - &vec_znx_pt, - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch_3, - ); - - vec_znx_pt.data.zero(); // zeroes for next iteration - - // Switch vec_znx_ct into DFT domain - module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); - module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); - - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct); - }); -} - -impl GLWEKeySwitchKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_key_switch_key_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } -} - -impl GetRow for GLWEKeySwitchKey +impl GetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToRef, { @@ -180,7 +73,7 @@ where } } -impl SetRow for GLWEKeySwitchKey +impl SetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, { @@ -196,138 +89,117 @@ where } } -impl KeySwitchScratchSpace for GLWEKeySwitchKey, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( +impl GLWESwitchingKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } +} + +impl GLWESwitchingKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.0.encrypt_sk( + module, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( module, res_size, lhs, rhs, ) } -} -impl KeySwitch for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWEKeySwitchKey; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: 
&Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( module, res_size, rhs, ) } -} -impl KeySwitchInplace for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, rhs, scratch); - } -} - -impl ExternalProductScratchSpace for GLWEKeySwitchKey, FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( module, res_size, lhs, rhs, ) } -} -impl ExternalProduct for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWEKeySwitchKey; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { , FFT64> as 
VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( module, res_size, rhs, ) } } -impl ExternalProductInplace for GLWEKeySwitchKey +impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, { - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe_inplace(module, self, scratch); - } -} - -impl VecGLWEProductScratchSpace for GLWEKeySwitchKey, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) - } -} - -impl VecGLWEProduct for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, + pub fn keyswitch( + &mut self, module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0 + .prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0 + .prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertext, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - let log_base2k: usize = self.basek(); + rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + } - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - 
assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - - { - let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); - module.vec_znx_dft(&mut a1_dft, 0, a, 1); - module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 97db860..cdd83d1 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,10 +1,10 @@ pub mod elem; -pub mod encryption; -pub mod external_product; -pub mod ggsw; -pub mod glwe; +pub mod gglwe_ciphertext; +pub mod ggsw_ciphertext; +pub mod glwe_ciphertext; +pub mod glwe_ciphertext_fourier; +pub mod glwe_plaintext; pub mod keys; -pub mod keyswitch; pub mod keyswitch_key; #[cfg(test)] mod test_fft64; diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/gglwe.rs similarity index 79% rename from core/src/test_fft64/grlwe.rs rename to core/src/test_fft64/gglwe.rs index 9d9a077..7a7de6d 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -3,15 +3,12 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos}, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertextFourier, GLWEPlaintext}, + 
ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::rgsw::noise_rgsw_product, + keyswitch_key::GLWESwitchingKey, + test_fft64::ggsw::noise_rgsw_product, }; #[test] @@ -20,11 +17,13 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let rows: usize = 4; + let rank: usize = 1; + let rank_out: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_ct, rows); + let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -35,14 +34,15 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module, rank); + // sk.fill_ternary_prob(0.5, &mut source_xs); + sk.fill_zero(); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -56,7 +56,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = 
GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); (0..ct.rows()).for_each(|row_i| { ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); @@ -74,21 +74,26 @@ fn keyswitch() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GLWEKeySwitchKey::keyswitch_scratch_space( + | GLWESwitchingKey::keyswitch_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -96,22 +101,22 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = 
SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module); + let mut sk2: SecretKey> = SecretKey::new(&module, rank); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -142,7 +147,7 @@ fn keyswitch() { ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { @@ -179,38 +184,43 @@ fn keyswitch_inplace() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let rank_out: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); let mut source_xs: Source = 
Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GLWEKeySwitchKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module); + let mut sk2: SecretKey> = SecretKey::new(&module, rank); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -240,10 +250,10 @@ fn keyswitch_inplace() { // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - let ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = ct_grlwe_s0s1; + let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = 
ct_grlwe_s0s1; let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { @@ -280,12 +290,17 @@ fn external_product() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let rank_out: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_in: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -295,15 +310,15 @@ fn external_product() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GLWEKeySwitchKey::external_product_scratch_space( + | GLWESwitchingKey::external_product_scratch_space( &module, ct_grlwe_out.size(), ct_grlwe_in.size(), ct_rgsw.size(), ) - | 
GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), ); let k: usize = 1; @@ -312,10 +327,10 @@ fn external_product() { pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -345,7 +360,7 @@ fn external_product() { ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -393,11 +408,15 @@ fn external_product_inplace() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank = 1; + let rank_out = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -407,10 +426,10 @@ fn external_product_inplace() { let mut source_xa: Source = Source::new([0u8; 32]); 
let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) - | GLWEKeySwitchKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), ); let k: usize = 1; @@ -419,10 +438,10 @@ fn external_product_inplace() { pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -452,7 +471,7 @@ fn external_product_inplace() { ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/ggsw.rs similarity index 87% rename from core/src/test_fft64/rgsw.rs rename to core/src/test_fft64/ggsw.rs index 820b671..ce16ea5 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -6,15 +6,12 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos}, - external_product::{ - ExternalProduct, ExternalProductInplace, 
ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertextFourier, GLWEPlaintext}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::grlwe::noise_grlwe_rlwe_product, + keyswitch_key::GLWESwitchingKey, + test_fft64::gglwe::noise_grlwe_rlwe_product, }; #[test] @@ -23,11 +20,12 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let rows: usize = 4; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -39,14 +37,14 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -60,7 +58,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = 
GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -98,12 +96,15 @@ fn keyswitch() { let log_k_rgsw_out: usize = 45; let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows); - let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); + let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -114,9 +115,9 @@ fn keyswitch() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) | GGSWCiphertext::keyswitch_scratch_space( &module, ct_rgsw_out.size(), @@ -125,16 +126,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = 
SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -161,7 +162,8 @@ fn keyswitch() { ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); @@ -215,12 +217,14 @@ fn keyswitch_inplace() { let log_k_grlwe: usize = 60; let log_k_rgsw: usize = 45; let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); let mut 
pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -231,22 +235,22 @@ fn keyswitch_inplace() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -273,7 +277,8 @@ fn keyswitch_inplace() { ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, 
ct_rgsw.size()); @@ -328,13 +333,16 @@ fn external_product() { let log_k_rgsw_lhs_in: usize = 45; let log_k_rgsw_lhs_out: usize = 45; let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); - let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); + let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = + GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); + let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = + GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -350,9 +358,9 @@ fn external_product() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) | GGSWCiphertext::external_product_scratch_space( &module, ct_rgsw_lhs_out.size(), @@ -361,10 +369,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = 
SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -392,7 +400,7 @@ fn external_product() { ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); @@ -457,12 +465,13 @@ fn external_product_inplace() { let log_k_rgsw_rhs: usize = 60; let log_k_rgsw_lhs: usize = 45; let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); + let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -478,16 +487,16 @@ fn external_product_inplace() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | 
GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -514,7 +523,8 @@ fn external_product_inplace() { ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/glwe.rs similarity index 84% rename from core/src/test_fft64/rlwe.rs rename to core/src/test_fft64/glwe.rs index 6958925..48b6cb6 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -7,16 +7,13 @@ use sampling::source::Source; use crate::{ elem::Infos, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, - keys::{PublicKey, SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, 
KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, }; #[test] @@ -25,11 +22,12 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let log_k_pt: usize = 30; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); let mut source_xs: Source = Source::new([0u8; 32]); @@ -37,13 +35,14 @@ fn encrypt_sk() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -93,6 +92,7 @@ fn encrypt_zero_sk() { let module: Module = Module::::new(1024); let log_base2k: usize = 8; let log_k_ct: usize = 55; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; @@ -103,16 +103,16 @@ fn 
encrypt_zero_sk() { let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) - | GLWECiphertextFourier::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, rank, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -135,11 +135,12 @@ fn encrypt_pk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let log_k_pk: usize = 64; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); @@ -147,12 +148,12 @@ fn encrypt_pk() { let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut 
pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, log_base2k, log_k_pk, rank); pk.generate( &module, &sk_dft, @@ -163,9 +164,9 @@ fn encrypt_pk() { ); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) - | GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), + | GLWECiphertext::encrypt_pk_scratch_space(&module, rank, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; @@ -206,13 +207,15 @@ fn keyswitch() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -226,9 +229,9 @@ fn keyswitch() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + 
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, ct_rlwe_out.size(), @@ -237,16 +240,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -305,12 +308,14 @@ fn keyswich_inplace() { let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); @@ -324,22 +329,22 @@ fn 
keyswich_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -399,13 +404,14 @@ fn external_product() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: 
GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -426,9 +432,9 @@ fn external_product() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), @@ -437,10 +443,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -511,12 +517,13 @@ fn external_product_inplace() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe: 
GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -537,16 +544,16 @@ fn external_product_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/glwe_fourier.rs similarity index 84% rename from core/src/test_fft64/rlwe_dft.rs rename to core/src/test_fft64/glwe_fourier.rs index 06359b1..661a1e5 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,15 +1,12 @@ use crate::{ elem::Infos, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, 
+ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, + keyswitch_key::GLWESwitchingKey, + test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; @@ -23,16 +20,19 @@ fn keyswitch() { let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -46,9 +46,9 @@ fn keyswitch() 
{ .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertextFourier::keyswitch_scratch_space( &module, ct_rlwe_out.size(), @@ -57,16 +57,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -127,13 +127,16 @@ fn keyswich_inplace() { let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); 
+ let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); @@ -147,22 +150,22 @@ fn keyswich_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -224,17 +227,18 @@ fn external_product() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut 
ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -255,9 +259,9 @@ fn external_product() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), @@ -266,10 +270,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, 
&mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -342,13 +346,15 @@ fn external_product_inplace() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -369,16 +375,16 @@ fn external_product_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); - let mut sk: SecretKey> = 
SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 59e2895..ffaf1dc 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,4 +1,4 @@ -mod grlwe; -mod rgsw; -mod rlwe; -mod rlwe_dft; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_fourier; diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs index 7920de9..63c4769 100644 --- a/core/src/vec_glwe_product.rs +++ b/core/src/vec_glwe_product.rs @@ -5,7 +5,8 @@ use base2k::{ use crate::{ elem::{GetRow, Infos, SetRow}, - glwe::{GLWECiphertext, GLWECiphertextFourier}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, }; pub(crate) trait VecGLWEProductScratchSpace { @@ -81,8 +82,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: a_data, - log_base2k: a.basek(), - log_k: a.k(), + basek: a.basek(), + k: a.k(), }; a.idft(module, &mut a_idft, scratch_1); @@ -91,8 +92,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); @@ -122,8 +123,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; res.idft(module, &mut res_idft, scratch_1); @@ -143,22 +144,22 @@ pub(crate) trait VecGLWEProduct: Infos { let mut tmp_a_row: GLWECiphertextFourier<&mut 
[u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, - log_base2k: a.basek(), - log_k: a.k(), + basek: a.basek(), + k: a.k(), }; let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; let min_rows: usize = res.rows().min(a.rows()); (0..res.rows()).for_each(|row_i| { - (0..res.rank()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { a.get_row(module, row_i, col_j, &mut tmp_a_row); self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); res.set_row(module, row_i, col_j, &tmp_res_row); @@ -168,7 +169,7 @@ pub(crate) trait VecGLWEProduct: Infos { tmp_res_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { - (0..self.rank()).for_each(|col_j| { + (0..self.cols()).for_each(|col_j| { res.set_row(module, row_i, col_j, &tmp_res_row); }); }); @@ -182,12 +183,12 @@ pub(crate) trait VecGLWEProduct: Infos { let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; (0..res.rows()).for_each(|row_i| { - (0..res.rank()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); From 66188a12a64faad608ec9a4e271bbce61aec10b9 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 09:10:05 +0200 Subject: [PATCH 63/87] added multiple rank glwe enc sk & fixed decryption for glwe --- core/src/glwe_ciphertext.rs | 27 +++++-- core/src/test_fft64/glwe.rs | 141 ++++++++++++++++++------------------ 2 files changed, 92 insertions(+), 76 deletions(-) diff --git 
a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index ed8a39e..56ca1a9 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -435,17 +435,30 @@ where VecZnx: VecZnxToMut, ScalarZnxDft: ScalarZnxDftToRef, { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols: usize = self.rank() + 1; + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.zero(); { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, self, 1); + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big = module.vec_znx_idft_consume(ci_dft); - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + // c0_big += a[i] * s[i] + module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); } // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 48b6cb6..5ffe2dd 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxViewMut, ZnxZero, + ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -17,18 +17,26 @@ use crate::{ }; #[test] -fn encrypt_sk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize 
= 54; - let log_k_pt: usize = 30; - let rank: usize = 1; +fn encrypt_sk_rank_1() { + encrypt_sk(11, 8, 54, 30, 3.2, 1); +} - let sigma: f64 = 3.2; +#[test] +fn encrypt_sk_rank_2() { + encrypt_sk(5, 8, 54, 30, 3.2, 2); +} + +#[test] +fn encrypt_sk_rank_3() { + encrypt_sk(11, 8, 54, 30, 3.2, 3); +} + +fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -51,8 +59,7 @@ fn encrypt_sk() { .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0xFF); - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); ct.encrypt_sk( &module, @@ -72,10 +79,10 @@ fn encrypt_sk() { let mut data_have: Vec = vec![0i64; module.n()]; pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { let b_scaled = (*b as f64) / scale; assert!( @@ -90,14 +97,14 @@ fn encrypt_sk() { #[test] fn encrypt_zero_sk() { let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; + let basek: usize = 8; + let k_ct: usize = 55; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut 
pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -108,7 +115,7 @@ fn encrypt_zero_sk() { let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) @@ -126,22 +133,22 @@ fn encrypt_zero_sk() { ); ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); } #[test] fn encrypt_pk() { let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; + let basek: usize = 8; + let k_ct: usize = 54; let log_k_pk: usize = 64; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -153,7 +160,7 @@ fn encrypt_pk() { let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, log_base2k, log_k_pk, rank); + let mut pk: GLWEPublicKey, FFT64> = 
GLWEPublicKey::new(&module, basek, log_k_pk, rank); pk.generate( &module, &sk_dft, @@ -175,9 +182,7 @@ fn encrypt_pk() { .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0); - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); ct.encrypt_pk( &module, @@ -190,34 +195,33 @@ fn encrypt_pk() { scratch.borrow(), ); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, basek) * (k_ct as f64).exp2()).abs() < 0.2); } #[test] fn keyswitch() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = 
GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -226,7 +230,7 @@ fn keyswitch() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) @@ -280,10 +284,10 @@ fn keyswitch() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_grlwe_rlwe_product( module.n() as f64, - log_base2k, + basek, 0.5, 0.5, 0f64, @@ -304,20 +308,19 @@ fn keyswitch() { #[test] fn keyswich_inplace() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, 
log_k_grlwe, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -326,7 +329,7 @@ fn keyswich_inplace() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) @@ -375,10 +378,10 @@ fn keyswich_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_grlwe_rlwe_product( module.n() as f64, - log_base2k, + basek, 0.5, 0.5, 0f64, @@ -399,22 +402,22 @@ fn keyswich_inplace() { #[test] fn external_product() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = 
GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -423,7 +426,7 @@ fn external_product() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); pt_want.to_mut().at_mut(0, 0)[1] = 1; @@ -479,7 +482,7 @@ fn external_product() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -490,7 +493,7 @@ fn external_product() { let noise_want: f64 = noise_rgsw_product( module.n() as f64, - log_base2k, + basek, 0.5, var_msg, var_a0_err, @@ -512,21 +515,21 @@ fn external_product() { #[test] fn external_product_inplace() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe: GLWECiphertext> = 
GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -535,7 +538,7 @@ fn external_product_inplace() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); pt_want.to_mut().at_mut(0, 0)[1] = 1; @@ -586,7 +589,7 @@ fn external_product_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -597,7 +600,7 @@ fn external_product_inplace() { let noise_want: f64 = noise_rgsw_product( module.n() as f64, - log_base2k, + basek, 0.5, var_msg, var_a0_err, From d489bef105ffed48fd88d17fac52b940549b7469 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 09:18:46 +0200 Subject: [PATCH 64/87] hard coded noise bound to 6 sigma --- core/src/gglwe_ciphertext.rs | 2 -- core/src/ggsw_ciphertext.rs | 2 -- core/src/glwe_ciphertext.rs | 21 +++++---------------- core/src/glwe_ciphertext_fourier.rs | 5 +---- core/src/keys.rs | 2 -- core/src/keyswitch_key.rs | 6 ++---- core/src/lib.rs | 2 ++ core/src/test_fft64/gglwe.rs | 18 
++---------------- core/src/test_fft64/ggsw.rs | 14 -------------- core/src/test_fft64/glwe.rs | 29 ++--------------------------- core/src/test_fft64/glwe_fourier.rs | 12 ------------ 11 files changed, 14 insertions(+), 99 deletions(-) diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 9d7c45a..9d2c79a 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -97,7 +97,6 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where ScalarZnx: ScalarZnxToRef, @@ -153,7 +152,6 @@ where source_xa, source_xe, sigma, - bound, scratch_3, ); diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 9d42df8..d277d78 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -119,7 +119,6 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where ScalarZnx: ScalarZnxToRef, @@ -168,7 +167,6 @@ where source_xa, source_xe, sigma, - bound, scrach_2, ); diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 56ca1a9..5f8d086 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -7,6 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ + SIX_SIGMA, elem::Infos, gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, @@ -145,7 +146,6 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where VecZnx: VecZnxToRef, @@ -158,7 +158,6 @@ where source_xa, source_xe, sigma, - bound, scratch, ); } @@ -170,14 +169,11 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where ScalarZnxDft: ScalarZnxDftToRef, { - self.encrypt_sk_private( - module, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ); + self.encrypt_sk_private(module, None, sk_dft, source_xa, source_xe, sigma, scratch); } pub fn encrypt_pk( @@ -188,7 
+184,6 @@ where source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where VecZnx: VecZnxToRef, @@ -201,7 +196,6 @@ where source_xu, source_xe, sigma, - bound, scratch, ); } @@ -213,14 +207,11 @@ where source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where VecZnxDft: VecZnxDftToRef, { - self.encrypt_pk_private( - module, None, pk, source_xu, source_xe, sigma, bound, scratch, - ); + self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); } pub fn keyswitch( @@ -279,7 +270,6 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where VecZnx: VecZnxToRef, @@ -335,7 +325,7 @@ where } // c[0] += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, sigma * SIX_SIGMA); // c[0] += m if col = 0 if let Some((pt, col)) = pt { @@ -356,7 +346,6 @@ where source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where VecZnx: VecZnxToRef, @@ -406,7 +395,7 @@ where let mut ci_big = module.vec_znx_idft_consume(ci_dft); // ci_big = u * pk[i] + e - ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); // ci_big = u * pk[i] + e + m (if col = i) if let Some((pt, col)) = pt { diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index bcc7648..e31d0dc 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -122,7 +122,6 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where ScalarZnxDft: ScalarZnxDftToRef, @@ -133,9 +132,7 @@ where basek: self.basek, k: self.k, }; - ct_idft.encrypt_zero_sk( - module, sk_dft, source_xa, source_xe, sigma, bound, scratch_1, - ); + 
ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1); ct_idft.dft(module, self); } diff --git a/core/src/keys.rs b/core/src/keys.rs index d57fa73..8a4d5e1 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -216,7 +216,6 @@ impl GLWEPublicKey { source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, ) where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, @@ -241,7 +240,6 @@ impl GLWEPublicKey { source_xa, source_xe, sigma, - bound, scratch.borrow(), ); self.dist = sk_dft.dist; diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 33d2a45..f9500a7 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -111,15 +111,13 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - bound: f64, scratch: &mut Scratch, ) where ScalarZnx: ScalarZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - self.0.encrypt_sk( - module, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ); + self.0 + .encrypt_sk(module, pt, sk_dft, source_xa, source_xe, sigma, scratch); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index cdd83d1..14392df 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -10,3 +10,5 @@ pub mod keyswitch_key; mod test_fft64; mod utils; pub mod vec_glwe_product; + +pub(crate) const SIX_SIGMA: f64 = 6.0; diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 7a7de6d..e4a566d 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -21,7 +21,6 @@ fn encrypt_sk() { let rank_out: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); @@ -52,7 +51,6 @@ fn encrypt_sk() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -77,7 +75,6 @@ fn keyswitch() { let rank: 
usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); @@ -127,7 +124,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -139,7 +135,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -188,7 +183,6 @@ fn keyswitch_inplace() { let rank_out: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); @@ -231,7 +225,6 @@ fn keyswitch_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -243,7 +236,6 @@ fn keyswitch_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -294,7 +286,6 @@ fn external_product() { let rank_out: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); @@ -341,7 +332,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -352,7 +342,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -408,11 +397,10 @@ fn external_product_inplace() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - let rank = 1; - let rank_out = 1; + let rank: usize = 1; + let rank_out: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); @@ -452,7 +440,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -463,7 +450,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, 
scratch.borrow(), ); diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index ce16ea5..f1903c1 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -23,7 +23,6 @@ fn encrypt_sk() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); @@ -54,7 +53,6 @@ fn encrypt_sk() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -99,7 +97,6 @@ fn keyswitch() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); @@ -145,7 +142,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -156,7 +152,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -220,7 +215,6 @@ fn keyswitch_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); @@ -260,7 +254,6 @@ fn keyswitch_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -271,7 +264,6 @@ fn keyswitch_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -336,7 +328,6 @@ fn external_product() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = @@ -382,7 +373,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -393,7 +383,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, 
scratch.borrow(), ); @@ -468,7 +457,6 @@ fn external_product_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); @@ -506,7 +494,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -517,7 +504,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 5ffe2dd..8e97eff 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxView, ZnxViewMut, ZnxZero, + ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -33,7 +33,6 @@ fn encrypt_sk_rank_3() { fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); - let bound: f64 = sigma * 6.0; let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_pt); @@ -68,7 +67,6 @@ fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -102,7 +100,6 @@ fn encrypt_zero_sk() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); @@ -128,7 +125,6 @@ fn encrypt_zero_sk() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); @@ -145,7 +141,6 @@ fn encrypt_pk() { let rank: 
usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); @@ -161,14 +156,7 @@ fn encrypt_pk() { sk_dft.dft(&module, &sk); let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, basek, log_k_pk, rank); - pk.generate( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - ); + pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) @@ -191,7 +179,6 @@ fn encrypt_pk() { &mut source_xu, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -215,7 +202,6 @@ fn keyswitch() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); @@ -263,7 +249,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -274,7 +259,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -315,7 +299,6 @@ fn keyswich_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe, rank); @@ -357,7 +340,6 @@ fn keyswich_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -368,7 +350,6 @@ fn keyswich_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -410,7 +391,6 @@ fn external_product() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw: 
GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); @@ -459,7 +439,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -470,7 +449,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -523,7 +501,6 @@ fn external_product_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); @@ -566,7 +543,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -577,7 +553,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 661a1e5..16f9eca 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -23,7 +23,6 @@ fn keyswitch() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); @@ -76,7 +75,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -87,7 +85,6 @@ fn keyswitch() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -130,7 +127,6 @@ fn keyswich_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); @@ -175,7 +171,6 @@ fn keyswich_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -186,7 +181,6 @@ fn 
keyswich_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -230,7 +224,6 @@ fn external_product() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); @@ -283,7 +276,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -294,7 +286,6 @@ fn external_product() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -349,7 +340,6 @@ fn external_product_inplace() { let rank: usize = 1; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); @@ -394,7 +384,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -405,7 +394,6 @@ fn external_product_inplace() { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); From cb1928802a35b08e0b09cb894c50128ddb04fddb Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 14:57:04 +0200 Subject: [PATCH 65/87] Added noise based test for glwe pk enc --- core/src/test_fft64/glwe.rs | 56 ++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 8e97eff..dca899b 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -17,21 +17,30 @@ use crate::{ }; #[test] -fn encrypt_sk_rank_1() { - encrypt_sk(11, 8, 54, 30, 3.2, 1); +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(11, 8, 54, 30, 3.2, rank); + }); } #[test] -fn encrypt_sk_rank_2() { 
- encrypt_sk(5, 8, 54, 30, 3.2, 2); +fn encrypt_zero_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_zero_sk rank: {}", rank); + test_encrypt_zero_sk(11, 8, 64, 3.2, rank); + }); } #[test] -fn encrypt_sk_rank_3() { - encrypt_sk(11, 8, 54, 30, 3.2, 3); +fn encrypt_pk() { + (1..4).for_each(|rank| { + println!("test encrypt_pk rank: {}", rank); + test_encrypt_pk(11, 8, 64, 64, 3.2, rank) + }); } -fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { +fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); @@ -92,14 +101,8 @@ fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, }); } -#[test] -fn encrypt_zero_sk() { - let module: Module = Module::::new(1024); - let basek: usize = 8; - let k_ct: usize = 55; - let rank: usize = 1; - - let sigma: f64 = 3.2; +fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); @@ -132,15 +135,8 @@ fn encrypt_zero_sk() { assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); } -#[test] -fn encrypt_pk() { - let module: Module = Module::::new(32); - let basek: usize = 8; - let k_ct: usize = 54; - let log_k_pk: usize = 64; - let rank: usize = 1; - - let sigma: f64 = 3.2; +fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); @@ -155,7 +151,7 @@ fn encrypt_pk() { let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut 
pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, basek, log_k_pk, rank); + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, basek, k_pk, rank); pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); let mut scratch: ScratchOwned = ScratchOwned::new( @@ -188,7 +184,15 @@ fn encrypt_pk() { module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, basek) * (k_ct as f64).exp2()).abs() < 0.2); + let noise_have: f64 = pt_want.data.std(0, basek).log2(); + let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); + + assert!( + (noise_have - noise_want).abs() < 0.2, + "{} {}", + noise_have, + noise_want + ); } #[test] From f517a730a3fba5c2d14f2fcd5088f1156065814c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 16:34:52 +0200 Subject: [PATCH 66/87] updated key-switch for rank switching & updated glwe key-switching test --- core/src/gglwe_ciphertext.rs | 118 +-- core/src/ggsw_ciphertext.rs | 62 +- core/src/glwe_ciphertext.rs | 33 +- core/src/glwe_ciphertext_fourier.rs | 33 +- core/src/keyswitch_key.rs | 46 +- core/src/test_fft64/gglwe.rs | 1003 +++++++++++------------ core/src/test_fft64/ggsw.rs | 1139 +++++++++++++-------------- core/src/test_fft64/glwe.rs | 100 ++- core/src/test_fft64/glwe_fourier.rs | 876 ++++++++++---------- core/src/vec_glwe_product.rs | 72 +- 10 files changed, 1806 insertions(+), 1676 deletions(-) diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 9d2c79a..2a86c63 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -52,6 +52,14 @@ impl GGLWECiphertext { pub fn rank(&self) -> usize { self.data.cols_out() - 1 } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } } impl MatZnxDftToMut for GGLWECiphertext @@ -104,7 +112,8 @@ where { 
#[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.rank_in(), pt.cols()); + assert_eq!(self.rank_out(), sk_dft.rank()); assert_eq!(self.n(), module.n()); assert_eq!(sk_dft.n(), module.n()); assert_eq!(pt.n(), module.n()); @@ -115,11 +124,12 @@ where let basek: usize = self.basek(); let k: usize = self.k(); - let cols: usize = self.rank() + 1; + let cols_in: usize = self.rank_in(); + let cols_out: usize = self.rank_out() + 1; let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size); - let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols_out, size); + let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols_out, size); let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, @@ -139,29 +149,42 @@ where k, }; - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); + // For each input column (i.e. 
rank) produces a GGLWE ciphertext of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + (0..cols_in).for_each(|col_i| { + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); - // rlwe encrypt of vec_znx_pt into vec_znx_ct - vec_znx_ct.encrypt_sk( - module, - &vec_znx_pt, - sk_dft, - source_xa, - source_xe, - sigma, - scratch_3, - ); + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + &vec_znx_pt, + sk_dft, + source_xa, + source_xe, + sigma, + scratch_3, + ); - vec_znx_pt.data.zero(); // zeroes for next iteration + vec_znx_pt.data.zero(); // zeroes for next iteration - // Switch vec_znx_ct into DFT domain - vec_znx_ct.dft(module, &mut vec_znx_ct_dft); + // Switch vec_znx_ct into DFT domain + vec_znx_ct.dft(module, &mut vec_znx_ct_dft); - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft); + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft); + }); }); } } @@ -174,10 +197,6 @@ where where VecZnxDft: VecZnxDftToMut, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_extract_row(res, self, row_i, col_j); } } @@ -190,20 +209,23 @@ where where VecZnxDft: VecZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_prepare_row(self, row_i, col_j, a); } } impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - 
module.bytes_of_vec_znx_dft(2, grlwe_size) + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + a_size: usize, + grlwe_size: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size) + + module.bytes_of_vec_znx_dft(rank_in, a_size))) } } @@ -222,30 +244,38 @@ where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.basek(); + let basek: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); + assert_eq!(a.rank(), self.rank_in()); + assert_eq!(res.rank(), self.rank_out()); + assert_eq!(res.basek(), basek); + assert_eq!(a.basek(), basek); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); } - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + let cols_in: usize = self.rank_in(); + let cols_out: usize = self.rank_out() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise { - let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); - module.vec_znx_dft(&mut a1_dft, 0, a, 1); - module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2); } let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, 
scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1); + }); } } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index d277d78..625e09b 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -82,27 +82,34 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank, rank, ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } @@ -265,9 +272,24 @@ 
where } impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + a_size: usize, + rgsw_size: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size) + + ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size) + + module.vmp_apply_tmp_bytes( + res_size, + a_size, + a_size, + rank_in + 1, + rank_out + 1, + rgsw_size, + )) | module.vec_znx_big_normalize_tmp_bytes()) } } @@ -290,6 +312,8 @@ where #[cfg(debug_assertions)] { + assert_eq!(self.rank(), a.rank()); + assert_eq!(self.rank(), res.rank()); assert_eq!(res.basek(), log_base2k); assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); @@ -297,18 +321,22 @@ where assert_eq!(a.n(), module.n()); } - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + let cols: usize = self.rank() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); - module.vec_znx_dft(&mut a_dft, 0, a, 0); - module.vec_znx_dft(&mut a_dft, 1, a, 1); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size()); + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut a_dft, col_i, a, col_i); + }); module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); } let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + (0..cols).for_each(|i| { + 
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1); + }); } } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 5f8d086..063ceb2 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -113,23 +113,34 @@ impl GLWECiphertext> { + module.bytes_of_vec_znx_big(1, ct_size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, rank, + ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank, rank, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as 
VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index e31d0dc..a16aba8 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -90,23 +90,34 @@ impl GLWECiphertextFourier, FFT64> { + module.bytes_of_vec_znx_big(1, ct_size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, rank, + ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank, rank, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: 
usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index f9500a7..37774eb 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -9,7 +9,7 @@ use crate::{ gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, - keys::SecretKeyFourier, + keys::{SecretKey, SecretKeyFourier}, vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; @@ -103,46 +103,60 @@ impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { - pub fn encrypt_sk( + pub fn encrypt_sk( &mut self, module: &Module, - pt: &ScalarZnx, - sk_dft: &SecretKeyFourier, + sk_in: &SecretKey, + sk_out_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, scratch: &mut Scratch, ) where - ScalarZnx: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - self.0 - .encrypt_sk(module, pt, sk_dft, source_xa, source_xe, sigma, scratch); + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out_dft, + source_xa, + source_xe, + sigma, + scratch, + ); } } impl GLWESwitchingKey, FFT64> { - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as 
VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank, rank, ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index e4a566d..8327325 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -1,504 +1,503 @@ -use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos}, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::ggsw::noise_rgsw_product, -}; - -#[test] -fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source 
= Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - // sk.fill_ternary_prob(0.5, &mut source_xs); - sk.fill_zero(); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); - - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - }); -} - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = 
Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GLWESwitchingKey::keyswitch_scratch_space( - &module, - ct_grlwe_s0s2.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module, rank); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - 
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn keyswitch_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module, rank); - sk2.fill_ternary_prob(0.5, &mut source_xs); - 
- let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - - let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, 
rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GLWESwitchingKey::external_product_scratch_space( - &module, - ct_grlwe_out.size(), - ct_grlwe_in.size(), - ct_rgsw.size(), - ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_in.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe_out.rows()).for_each(|row_i| { - ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - 
ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) - | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - 
pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe.rows()).for_each(|row_i| { - ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -pub(crate) fn noise_grlwe_rlwe_product( +// use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +// use sampling::source::Source; +// +// use crate::{ 
+// elem::{GetRow, Infos}, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// test_fft64::ggsw::noise_rgsw_product, +// }; +// +// #[test] +// fn encrypt_sk() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 8; +// let log_k_ct: usize = 54; +// let rows: usize = 4; +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// sk.fill_zero(); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct.encrypt_sk( +// &module, +// &pt_scalar, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); +// +// (0..ct.rows()).for_each(|row_i| { +// ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, 
&pt_scalar, 0); +// let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); +// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); +// }); +// } +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) +// | GLWESwitchingKey::keyswitch_scratch_space( +// &module, +// ct_grlwe_s0s2.size(), +// ct_grlwe_s0s1.size(), +// ct_grlwe_s1s2.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// let mut sk2: SecretKey> = SecretKey::new(&module, rank); +// sk2.fill_ternary_prob(0.5, &mut source_xs); +// +// let 
mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk2_dft.dft(&module, &sk2); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_s0s1.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s2}(s1) -> s1 -> s2 +// ct_grlwe_s1s2.encrypt_sk( +// &module, +// &sk1.data, +// &sk2_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) +// ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { +// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn keyswitch_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_grlwe_s1s2: GLWESwitchingKey, 
FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) +// | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// let mut sk2: SecretKey> = SecretKey::new(&module, rank); +// sk2.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk2_dft.dft(&module, &sk2); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_s0s1.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s2}(s1) -> s1 -> s2 +// ct_grlwe_s1s2.encrypt_sk( +// &module, +// &sk1.data, +// &sk2_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) +// ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); +// +// let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, 
log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { +// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) +// | 
GLWESwitchingKey::external_product_scratch_space( +// &module, +// ct_grlwe_out.size(), +// ct_grlwe_in.size(), +// ct_rgsw.size(), +// ) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), +// ); +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_in.encrypt_sk( +// &module, +// &pt_grlwe, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) +// ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); +// +// (0..ct_grlwe_out.rows()).for_each(|row_i| { +// ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 
0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) +// | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), +// ); +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe.encrypt_sk( +// &module, +// &pt_grlwe, +// &sk_dft, +// &mut 
source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) +// ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); +// +// (0..ct_grlwe.rows()).for_each(|row_i| { +// ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +pub(crate) fn noise_gglwe_product( n: f64, log_base2k: usize, var_xs: f64, @@ -506,6 +505,7 @@ pub(crate) fn noise_grlwe_rlwe_product( var_a_err: f64, var_gct_err_lhs: f64, var_gct_err_rhs: f64, + rank_in: f64, a_logq: usize, b_logq: usize, ) -> f64 { @@ -522,6 +522,7 @@ pub(crate) fn noise_grlwe_rlwe_product( // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * 
var_gct_err_rhs); noise += var_msg * var_a_err * a_scale * a_scale * n; + noise *= rank_in; noise = noise.sqrt(); noise /= b_scale; noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index f1903c1..c514ef9 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,573 +1,572 @@ -use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos}, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::gglwe::noise_grlwe_rlwe_product, -}; - -#[test] -fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - 
let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.rank()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rgsw_in: usize = 45; - let log_k_rgsw_out: usize = 45; - let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); - let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, 
rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) - | GGSWCiphertext::keyswitch_scratch_space( - &module, - ct_rgsw_out.size(), - ct_rgsw_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_in.encrypt_sk( - &module, - &pt_rgsw, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - 
- (0..ct_rgsw_out.rank()).for_each(|col_j| { - (0..ct_rgsw_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.2, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rgsw: usize = 45; - let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | 
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) - | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - - (0..ct_rgsw.rank()).for_each(|col_j| { - (0..ct_rgsw.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - 
ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.2, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_rgsw_rhs: usize = 60; - let log_k_rgsw_lhs_in: usize = 45; - let log_k_rgsw_lhs_out: usize = 45; - let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); - let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = - GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); - let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = - GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); - let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) - | 
GGSWCiphertext::external_product_scratch_space( - &module, - ct_rgsw_lhs_out.size(), - ct_rgsw_lhs_in.size(), - ct_rgsw_rhs.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw_rhs.encrypt_sk( - &module, - &pt_rgsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs_in.encrypt_sk( - &module, - &pt_rgsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - - (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { - (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, 
log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rgsw_lhs_in, - log_k_rgsw_rhs, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_rgsw_rhs: usize = 60; - let log_k_rgsw_lhs: usize = 45; - let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); - let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); - let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) - | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), - ); - - let 
mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw_rhs.encrypt_sk( - &module, - &pt_rgsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs.encrypt_sk( - &module, - &pt_rgsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - - (0..ct_rgsw_lhs.rank()).for_each(|col_j| { - (0..ct_rgsw_lhs.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - 
let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rgsw_lhs, - log_k_rgsw_rhs, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - +// use base2k::{ +// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, +// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +// }; +// use sampling::source::Source; +// +// use crate::{ +// elem::{GetRow, Infos}, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// test_fft64::gglwe::noise_grlwe_rlwe_product, +// }; +// +// #[test] +// fn encrypt_sk() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 8; +// let log_k_ct: usize = 54; +// let rows: usize = 4; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), +// ); +// 
+// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct.encrypt_sk( +// &module, +// &pt_scalar, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); +// +// (0..ct.rank()).for_each(|col_j| { +// (0..ct.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// +// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); +// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rgsw_in: usize = 45; +// let log_k_rgsw_out: usize = 45; +// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, 
rank); +// let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); +// let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) +// | GGSWCiphertext::keyswitch_scratch_space( +// &module, +// ct_rgsw_out.size(), +// ct_rgsw_in.size(), +// ct_grlwe.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_in.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); +// 
let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); +// +// (0..ct_rgsw_out.rank()).for_each(|col_j| { +// (0..ct_rgsw_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.2, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn keyswitch_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rgsw: usize = 45; +// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, 
log_k_rgsw, rows, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); +// let mut pt_want: 
GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); +// +// (0..ct_rgsw.rank()).for_each(|col_j| { +// (0..ct_rgsw.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.2, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_rgsw_rhs: usize = 60; +// let log_k_rgsw_lhs_in: usize = 45; +// let log_k_rgsw_lhs_out: usize = 45; +// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); +// let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = +// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); +// let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = +// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); +// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); +// let mut 
pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let k: usize = 1; +// +// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) +// | GGSWCiphertext::external_product_scratch_space( +// &module, +// ct_rgsw_lhs_out.size(), +// ct_rgsw_lhs_in.size(), +// ct_rgsw_rhs.size(), +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw_rhs.encrypt_sk( +// &module, +// &pt_rgsw_rhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs_in.encrypt_sk( +// &module, +// &pt_rgsw_lhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, 
log_base2k, log_k_rgsw_lhs_out); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); +// +// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { +// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rgsw_lhs_in, +// log_k_rgsw_rhs, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_rgsw_rhs: usize = 60; +// let log_k_rgsw_lhs: usize = 45; +// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); +// let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = 
GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); +// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let k: usize = 1; +// +// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) +// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw_rhs.encrypt_sk( +// &module, +// &pt_rgsw_rhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs.encrypt_sk( +// &module, +// &pt_rgsw_lhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); +// let 
mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); +// +// (0..ct_rgsw_lhs.rank()).for_each(|col_j| { +// (0..ct_rgsw_lhs.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rgsw_lhs, +// log_k_rgsw_rhs, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } pub(crate) fn noise_rgsw_product( n: f64, log_base2k: usize, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index dca899b..5f2c876 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -13,7 +13,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, + 
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_rgsw_product}, }; #[test] @@ -197,21 +197,32 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: #[test] fn keyswitch() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + basek - 1) / basek; - let rank: usize = 1; + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); + test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); + }); + }); +} - let sigma: f64 = 3.2; +fn test_keyswitch( + log_n: usize, + basek: usize, + k_keyswitch: usize, + k_ct_in: usize, + k_ct_out: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_keyswitch, rows, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = 
Source::new([0u8; 32]); @@ -223,57 +234,59 @@ fn keyswitch() { .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), + ct_out.size(), + ct_in.size(), + ksk.size(), + rank_in, + rank_out, ), ); - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); - ct_grlwe.encrypt_sk( + ksk.encrypt_sk( &module, - &sk0.data, - &sk1_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_rlwe_in.encrypt_sk( + ct_in.encrypt_sk( &module, &pt_want, - &sk0_dft, + &sk_in_dft, &mut source_xa, &mut 
source_xe, sigma, scratch.borrow(), ); - ct_rlwe_out.keyswitch(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( + let noise_want: f64 = noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -281,8 +294,9 @@ fn keyswitch() { 0f64, sigma * sigma, 0f64, - log_k_rlwe_in, - log_k_grlwe, + rank_in as f64, + k_ct_in, + k_keyswitch, ); assert!( @@ -322,7 +336,7 @@ fn keyswich_inplace() { GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size(), rank), ); let mut sk0: SecretKey> = SecretKey::new(&module, rank); @@ -339,7 +353,7 @@ fn keyswich_inplace() { ct_grlwe.encrypt_sk( &module, - &sk0.data, + &sk0, &sk1_dft, &mut source_xa, &mut source_xe, @@ -364,7 +378,7 @@ fn keyswich_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( + let noise_want: f64 = noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -372,6 +386,7 @@ fn keyswich_inplace() { 0f64, sigma * sigma, 0f64, + rank as f64, log_k_rlwe, log_k_grlwe, ); @@ -427,6 +442,7 @@ fn external_product() { ct_rlwe_out.size(), ct_rlwe_in.size(), ct_rgsw.size(), + rank, ), ); @@ -531,7 +547,7 @@ fn external_product_inplace() { GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, 
ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), ); let mut sk: SecretKey> = SecretKey::new(&module, rank); diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 16f9eca..f25bac9 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,438 +1,438 @@ -use crate::{ - elem::Infos, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, -}; -use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; -use sampling::source::Source; - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = - 
GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) - | GLWECiphertextFourier::keyswitch_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - &pt_want, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 
0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn keyswich_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, 
rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - &pt_want, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = 
- GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) - | GLWECiphertext::external_product_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - 
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, 
ct_rgsw.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} +// use crate::{ +// elem::Infos, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext::GLWECiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// 
test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, +// }; +// use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; +// use sampling::source::Source; +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) +// | 
GLWECiphertextFourier::keyswitch_scratch_space( +// &module, +// ct_rlwe_out.size(), +// ct_rlwe_in.size(), +// ct_grlwe.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.encrypt_sk( +// &module, +// &pt_want, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); +// ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); +// ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); +// +// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn keyswich_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe: usize = 45; +// let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: 
GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) +// | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.encrypt_sk( +// &module, +// &pt_want, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// 
ct_rlwe.dft(&module, &mut ct_rlwe_dft); +// ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); +// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); +// +// ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_rlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = 
Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// pt_want.to_mut().at_mut(0, 0)[1] = 1; +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) +// | GLWECiphertext::external_product_scratch_space( +// &module, +// ct_rlwe_out.size(), +// ct_rlwe_in.size(), +// ct_rgsw.size(), +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.encrypt_sk( +// &module, +// &pt_want, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); +// ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); +// ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); +// +// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 
1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// pt_want.to_mut().at_mut(0, 0)[1] = 1; +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, 
ct_rlwe.size()) +// | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.encrypt_sk( +// &module, +// &pt_want, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.dft(&module, &mut ct_rlwe_dft); +// ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); +// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); +// +// ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs index 63c4769..d3e6636 100644 --- a/core/src/vec_glwe_product.rs +++ b/core/src/vec_glwe_product.rs @@ -10,31 +10,53 @@ use crate::{ }; pub(crate) trait VecGLWEProductScratchSpace { - fn prod_with_glwe_scratch_space(module: 
&Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize; - fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs) + fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank) } - fn prod_with_glwe_dft_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, lhs) - + module.bytes_of_vec_znx(2, res_size) + fn prod_with_glwe_dft_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(rank_in, lhs) + + module.bytes_of_vec_znx(rank_out, res_size) } - fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) + fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(rank + 1, res_size) } - fn prod_with_vec_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs) - + module.bytes_of_vec_znx_dft(2, lhs) - + 
module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_vec_glwe_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) + + module.bytes_of_vec_znx_dft(rank_in + 1, lhs) + + module.bytes_of_vec_znx_dft(rank_out + 1, res_size) } - fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs) + module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs, rank) + + module.bytes_of_vec_znx_dft(rank + 1, res_size) } } @@ -78,7 +100,7 @@ pub(crate) trait VecGLWEProduct: Infos { assert_eq!(res.n(), module.n()); } - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size()); let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: a_data, @@ -88,7 +110,7 @@ pub(crate) trait VecGLWEProduct: Infos { a.idft(module, &mut a_idft, scratch_1); - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size()); let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, @@ -98,8 +120,7 @@ pub(crate) trait VecGLWEProduct: Infos { self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); + res_idft.dft(module, res); } fn prod_with_glwe_fourier_inplace( @@ -119,7 +140,7 @@ pub(crate) trait VecGLWEProduct: Infos { assert_eq!(res.n(), module.n()); } - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, 
res.size()); + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size()); let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, @@ -131,8 +152,7 @@ pub(crate) trait VecGLWEProduct: Infos { self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); + res_idft.dft(module, res); } fn prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) @@ -140,7 +160,7 @@ pub(crate) trait VecGLWEProduct: Infos { LHS: GetRow + Infos, RES: SetRow + Infos, { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size()); let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, @@ -148,7 +168,7 @@ pub(crate) trait VecGLWEProduct: Infos { k: a.k(), }; - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size()); let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_res_data, @@ -179,7 +199,7 @@ pub(crate) trait VecGLWEProduct: Infos { where RES: GetRow + SetRow + Infos, { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size()); let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, From 4c55a7df444779419b22285d8dcbdd8a0f3f2694 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 16:57:57 +0200 Subject: [PATCH 67/87] updated ggsw product noise prediction & added test for ggsw x glwe of rank > 1 --- core/src/test_fft64/ggsw.rs | 5 +- core/src/test_fft64/glwe.rs | 
114 +++++++++++++++++++----------------- 2 files changed, 62 insertions(+), 57 deletions(-) diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index c514ef9..9420831 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -576,6 +576,7 @@ pub(crate) fn noise_rgsw_product( var_a1_err: f64, var_gct_err_lhs: f64, var_gct_err_rhs: f64, + rank: f64, a_logq: usize, b_logq: usize, ) -> f64 { @@ -590,9 +591,9 @@ pub(crate) fn noise_rgsw_product( // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs - let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); noise += var_msg * var_a0_err * a_scale * a_scale * n; - noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs; + noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank; noise = noise.sqrt(); noise /= b_scale; noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 5f2c876..008e761 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -40,6 +40,24 @@ fn encrypt_pk() { }); } +#[test] +fn keyswitch() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); + test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); + }); + }); +} + +#[test] +fn keyswitch_inplace() { + (1..4).for_each(|rank| { + println!("test keyswitch_inplace rank: {}", rank); + test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2); + }); +} + fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); @@ -195,16 +213,6 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, 
k_pk: usize, sigma: ); } -#[test] -fn keyswitch() { - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); - test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); - }); - }); -} - fn test_keyswitch( log_n: usize, basek: usize, @@ -307,21 +315,14 @@ fn test_keyswitch( ); } -#[test] -fn keyswich_inplace() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + basek - 1) / basek; - let rank: usize = 1; +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -387,8 +388,8 @@ fn keyswich_inplace() { sigma * sigma, 0f64, rank as f64, - log_k_rlwe, - log_k_grlwe, + k_ct, + k_ksk, ); assert!( @@ -401,22 +402,23 @@ fn keyswich_inplace() { #[test] fn external_product() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let 
log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + basek - 1) / basek; - let rank: usize = 1; + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, 45, 60, rank, 3.2); + }); +} - let sigma: f64 = 3.2; +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -498,8 +500,9 @@ fn external_product() { var_a1_err, var_gct_err_lhs, var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, + rank as f64, + k_ct_in, + k_ggsw, ); assert!( @@ -512,21 +515,21 @@ fn external_product() { #[test] fn external_product_inplace() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let 
log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + basek - 1) / basek; - let rank: usize = 1; + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, 60, rank, 3.2); + }); +} - let sigma: f64 = 3.2; +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -602,8 +605,9 @@ fn external_product_inplace() { var_a1_err, var_gct_err_lhs, var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, + rank as f64, + k_ct, + k_ggsw, ); assert!( From 67594e2e3f382a13a9b6e0844550740f3b3f834f Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 18:24:45 +0200 Subject: [PATCH 68/87] fixed benchmarks --- core/benches/external_product_glwe_fft64.rs | 81 ++++++----- core/benches/keyswitch_glwe_fft64.rs | 143 +++++++++++--------- core/src/gglwe_ciphertext.rs | 2 +- core/src/ggsw_ciphertext.rs | 2 +- core/src/glwe_ciphertext.rs | 4 +- 
core/src/glwe_ciphertext_fourier.rs | 2 +- core/src/test_fft64/glwe.rs | 15 +- 7 files changed, 128 insertions(+), 121 deletions(-) diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index 435a25f..1739211 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -2,12 +2,8 @@ use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use rlwe::{ elem::Infos, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, ggsw_ciphertext::GGSWCiphertext, - glwe::GLWECiphertext, + glwe_ciphertext::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, }; use sampling::source::Source; @@ -18,36 +14,38 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { struct Params { log_n: usize, basek: usize, - k_rlwe_in: usize, - k_rlwe_out: usize, - k_rgsw: usize, + k_ct_in: usize, + k_ct_out: usize, + k_ggsw: usize, + rank: usize, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); let basek: usize = p.basek; - let k_rlwe_in: usize = p.k_rlwe_in; - let k_rlwe_out: usize = p.k_rlwe_out; - let k_rgsw: usize = p.k_rgsw; + let k_ct_in: usize = p.k_ct_in; + let k_ct_out: usize = p.k_ct_out; + let k_ggsw: usize = p.k_ggsw; + let rank: usize = p.rank; - let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = 
GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); let mut scratch = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), ct_rgsw.size(), + rank, ), ); @@ -55,9 +53,9 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -67,7 +65,6 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -77,7 +74,6 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -94,9 +90,10 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let params_set: Vec = vec![Params { log_n: 10, basek: 7, - k_rlwe_in: 27, - k_rlwe_out: 27, - k_rgsw: 27, + k_ct_in: 27, + k_ct_out: 27, + k_ggsw: 27, + rank: 1, }]; for params in params_set { @@ -114,38 +111,39 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { struct Params { log_n: usize, basek: usize, - k_rlwe: usize, - k_rgsw: usize, + k_ct: usize, + k_ggsw: usize, + rank: usize, } fn runner(p: Params) 
-> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); let basek: usize = p.basek; - let k_rlwe: usize = p.k_rlwe; - let k_rgsw: usize = p.k_rgsw; + let k_glwe: usize = p.k_ct; + let k_ggsw: usize = p.k_ggsw; + let rank: usize = p.rank; - let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let rows: usize = (p.k_ct + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_glwe, rank); let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); let mut scratch = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), ); let mut source_xs = Source::new([0u8; 32]); let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -155,7 +153,6 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); @@ -165,13 +162,12 @@ fn 
bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); move || { let scratch_borrow = scratch.borrow(); - (0..1374).for_each(|i| { + (0..687).for_each(|_| { ct_rlwe.external_product_inplace( black_box(&module), black_box(&ct_rgsw), @@ -182,10 +178,11 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { } let params_set: Vec = vec![Params { - log_n: 9, + log_n: 12, basek: 18, - k_rlwe: 27, - k_rgsw: 27, + k_ct: 54, + k_ggsw: 54, + rank: 1, }]; for params in params_set { diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 3a25360..1c1b7f8 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,12 +1,10 @@ -use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; +use base2k::{FFT64, Module, ScratchOwned}; use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use rlwe::{ elem::Infos, - encryption::EncryptSkScratchSpace, - glwe::GLWECiphertext, + glwe_ciphertext::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, + keyswitch_key::GLWESwitchingKey, }; use sampling::source::Source; @@ -16,36 +14,40 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { struct Params { log_n: usize, basek: usize, - k_rlwe_in: usize, - k_rlwe_out: usize, - k_grlwe: usize, + k_ct_in: usize, + k_ct_out: usize, + k_ksk: usize, + rank_in: usize, + rank_out: usize, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); let basek: usize = p.basek; - let k_rlwe_in: usize = p.k_rlwe_in; - let k_rlwe_out: usize = p.k_rlwe_out; - let k_grlwe: usize = p.k_grlwe; + let k_rlwe_in: usize = p.k_ct_in; + let k_rlwe_out: usize = p.k_ct_out; + let k_grlwe: usize = p.k_ksk; + let rank_in: usize = p.rank_in; + let rank_out: usize = 
p.rank_out; - let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); - let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_grlwe, rows, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out, rank_out); let mut scratch = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), + ct_out.size(), + ct_in.size(), + ksk.size(), + rank_in, + rank_out, ), ); @@ -53,37 +55,40 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); - ct_grlwe.encrypt_sk( + let mut sk_out: SecretKey> = 
SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( &module, - &pt_grlwe, - &sk_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); - ct_rlwe_in.encrypt_zero_sk( + ct_in.encrypt_zero_sk( &module, - &sk_dft, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); move || { - ct_rlwe_out.keyswitch( + ct_out.keyswitch( black_box(&module), - black_box(&ct_rlwe_in), - black_box(&ct_grlwe), + black_box(&ct_in), + black_box(&ksk), black_box(scratch.borrow()), ); } @@ -92,9 +97,11 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let params_set: Vec = vec![Params { log_n: 16, basek: 50, - k_rlwe_in: 1250, - k_rlwe_out: 1250, - k_grlwe: 1250 + 66, + k_ct_in: 1250, + k_ct_out: 1250, + k_ksk: 1250 + 66, + rank_in: 1, + rank_out: 1, }]; for params in params_set { @@ -112,65 +119,68 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { struct Params { log_n: usize, basek: usize, - k_rlwe: usize, - k_grlwe: usize, + k_ct: usize, + k_ksk: usize, + rank: usize, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); let basek: usize = p.basek; - let k_rlwe: usize = p.k_rlwe; - let k_grlwe: usize = p.k_grlwe; + let k_ct: usize = p.k_ct; + let k_ksk: usize = p.k_ksk; + let rank: usize = p.rank; - let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let rows: usize = (p.k_ct + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); - let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, 
rank, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut scratch = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), rank), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); - ct_grlwe.encrypt_sk( + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( &module, - &pt_grlwe, - &sk_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); - ct_rlwe.encrypt_zero_sk( + ct.encrypt_zero_sk( &module, - &sk_dft, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, - bound, scratch.borrow(), ); move || { - ct_rlwe.keyswitch_inplace( + ct.keyswitch_inplace( black_box(&module), - black_box(&ct_grlwe), + black_box(&ksk), black_box(scratch.borrow()), ); } @@ -179,8 +189,9 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let params_set: Vec 
= vec![Params { log_n: 9, basek: 18, - k_rlwe: 27, - k_grlwe: 27, + k_ct: 27, + k_ksk: 27, + rank: 1, }]; for params in params_set { diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 2a86c63..ae4329c 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -82,7 +82,7 @@ where impl GGLWECiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(rank + 1, size) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 625e09b..fa3b365 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -76,7 +76,7 @@ where impl GGSWCiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(rank + 1, size) diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 063ceb2..82e44da 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -97,12 +97,12 @@ where } impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, _rank: usize, ct_size: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize { module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, ct_size) + module.bytes_of_vec_znx_big(1, ct_size) } - pub fn encrypt_pk_scratch_space(module: &Module, _rank: usize, pk_size: usize) -> usize { + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { ((module.bytes_of_vec_znx_dft(1, pk_size) + 
module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index a16aba8..b302d5e 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -80,7 +80,7 @@ impl GLWECiphertextFourier, FFT64> { } pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, ct_size: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, rank, ct_size) + module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, ct_size) } pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 008e761..4ceac95 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -69,8 +69,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module, rank); @@ -173,9 +172,9 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) - | GLWECiphertext::encrypt_pk_scratch_space(&module, rank, pk.size()), + | 
GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; @@ -244,7 +243,7 @@ fn test_keyswitch( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, ct_out.size(), @@ -336,7 +335,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size(), rank), ); @@ -438,7 +437,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), @@ -549,7 +548,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) + | 
GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), ); From 723a41acd0c37fd63306dd05973aca97f228c3b6 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 15 May 2025 10:45:06 +0200 Subject: [PATCH 69/87] fixed tests for ciphertext fourier --- base2k/examples/rlwe_encrypt.rs | 2 - base2k/src/lib.rs | 3 +- base2k/src/mat_znx_dft_ops.rs | 4 - core/src/gglwe_ciphertext.rs | 11 + core/src/glwe_ciphertext_fourier.rs | 8 +- core/src/test_fft64/ggsw.rs | 2 +- core/src/test_fft64/glwe.rs | 38 +- core/src/test_fft64/glwe_fourier.rs | 883 ++++++++++++++-------------- core/src/vec_glwe_product.rs | 12 +- 9 files changed, 487 insertions(+), 476 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 4db6ef5..e73db89 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -130,6 +130,4 @@ fn main() { .for_each(|(i, (a, b))| { println!("{}: {} {}", i, a, (*b as f64) / scale); }); - - module.free(); } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index b6ed099..89a52ef 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -150,8 +150,7 @@ impl Scratch { unsafe { &mut *(data as *mut [u8] as *mut Self) } } - #[allow(dead_code)] - fn available(&self) -> usize { + pub fn available(&self) -> usize { let ptr: *const u8 = self.data.as_ptr(); let self_len: usize = self.data.len(); let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index f302e9b..7b4ac36 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -337,8 +337,6 @@ mod tests { assert_eq!(a_dft.raw(), b_dft.raw()); } } - - module.free(); } #[test] @@ -425,7 +423,5 @@ mod tests { }); }); }); - - module.free(); } } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index ae4329c..d20072a 100644 
--- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -255,6 +255,17 @@ where assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); + assert!( + scratch.available() + >= GGLWECiphertext::prod_with_glwe_scratch_space( + module, + res.size(), + a.size(), + self.size(), + self.rank_in(), + self.rank_out() + ) + ); } let cols_in: usize = self.rank_in(); diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index b302d5e..fe2a50d 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -98,25 +98,25 @@ impl GLWECiphertextFourier, FFT64> { rank_in: usize, rank_out: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( module, res_size, lhs, rhs, rank_in, rank_out, ) } pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( module, res_size, rhs, rank, ) } pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( module, res_size, lhs, rhs, rank, rank, ) } pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( module, res_size, rhs, rank, ) } diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 9420831..eb8c532 100644 --- 
a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -567,7 +567,7 @@ // }); // }); // } -pub(crate) fn noise_rgsw_product( +pub(crate) fn noise_ggsw_gglwe_product( n: f64, log_base2k: usize, var_xs: f64, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 4ceac95..21bae6d 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -13,7 +13,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_rgsw_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, }; #[test] @@ -58,6 +58,22 @@ fn keyswitch_inplace() { }); } +#[test] +fn external_product() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, 45, 60, rank, 3.2); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, 60, rank, 3.2); + }); +} + fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); @@ -399,14 +415,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); } -#[test] -fn external_product() { - (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product(12, 12, 60, 45, 60, rank, 3.2); - }); -} - fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); @@ -490,7 +498,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_rgsw_product( + let noise_want: f64 = noise_ggsw_gglwe_product( 
module.n() as f64, basek, 0.5, @@ -512,14 +520,6 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi ); } -#[test] -fn external_product_inplace() { - (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product_inplace(12, 15, 60, 60, rank, 3.2); - }); -} - fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k_ct + basek - 1) / basek; @@ -595,7 +595,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_rgsw_product( + let noise_want: f64 = noise_ggsw_gglwe_product( module.n() as f64, basek, 0.5, diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index f25bac9..d5ed622 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,438 +1,445 @@ -// use crate::{ -// elem::Infos, -// ggsw_ciphertext::GGSWCiphertext, -// glwe_ciphertext::GLWECiphertext, -// glwe_ciphertext_fourier::GLWECiphertextFourier, -// glwe_plaintext::GLWEPlaintext, -// keys::{SecretKey, SecretKeyFourier}, -// keyswitch_key::GLWESwitchingKey, -// test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, -// }; -// use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; -// use sampling::source::Source; -// -// #[test] -// fn keyswitch() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rlwe_in: usize = 45; -// let log_k_rlwe_out: usize = 60; -// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// 
GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); -// let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_want -// .data -// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) -// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) -// | GLWECiphertextFourier::keyswitch_scratch_space( -// &module, -// ct_rlwe_out.size(), -// ct_rlwe_in.size(), -// ct_grlwe.size(), -// ), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// 
&sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe_in.encrypt_sk( -// &module, -// &pt_want, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); -// ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); -// ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); -// -// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_rlwe_in, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// } -// -// #[test] -// fn keyswich_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rlwe: usize = 45; -// let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = 
Source::new([0u8; 32]); -// -// Random input plaintext -// pt_want -// .data -// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) -// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) -// | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe.encrypt_sk( -// &module, -// &pt_want, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe.dft(&module, &mut ct_rlwe_dft); -// ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); -// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); -// -// ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_rlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// 
noise_want -// ); -// } -// -// #[test] -// fn external_product() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rlwe_in: usize = 45; -// let log_k_rlwe_out: usize = 60; -// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); -// let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_want -// .data -// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); -// -// pt_want.to_mut().at_mut(0, 0)[1] = 1; -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) -// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) -// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) -// | GLWECiphertext::external_product_scratch_space( -// &module, -// 
ct_rlwe_out.size(), -// ct_rlwe_in.size(), -// ct_rgsw.size(), -// ), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe_in.encrypt_sk( -// &module, -// &pt_want, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); -// ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); -// ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); -// -// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rlwe_in, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// } -// -// #[test] -// fn external_product_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rlwe_in: usize = 45; -// let log_k_rlwe_out: usize = 60; -// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; -// let rank: 
usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_want -// .data -// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); -// -// pt_want.to_mut().at_mut(0, 0)[1] = 1; -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) -// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) -// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) -// | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe.encrypt_sk( -// &module, -// &pt_want, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rlwe.dft(&module, &mut 
ct_rlwe_dft); -// ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); -// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); -// -// ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rlwe_in, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// } +use crate::{ + elem::Infos, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, +}; +use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; +use sampling::source::Source; + +#[test] +fn keyswitch() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); + test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); + }); + }); +} + +#[test] +fn keyswitch_inplace() { + (1..4).for_each(|rank| { + println!("test keyswitch_inplace rank: {}", rank); + test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank| { + 
println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, 45, 60, rank, 3.2); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, 60, rank, 3.2); + }); +} + +fn test_keyswitch( + log_n: usize, + basek: usize, + k_ksk: usize, + k_ct_in: usize, + k_ct_out: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in); + let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank_in); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out); + let mut ct_glwe_dft_out: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, k_ct_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_glwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe_in.size()) + | GLWECiphertextFourier::keyswitch_scratch_space( + &module, + ct_glwe_out.size(), + ct_glwe_in.size(), + ksk.size(), + rank_in, + rank_out, + ), + ); + + let mut 
sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); + ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); + ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); + + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_ct_in, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let 
mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_glwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe.size()) + | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ksk.size(), rank), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + 
(noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank); + let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertextFourier::external_product_scratch_space(&module, ct_out.size(), ct_in.size(), ct_rgsw.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + 
sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.dft(&module, &mut ct_in_dft); + ct_out_dft.external_product(&module, &ct_in_dft, &ct_rgsw, scratch.borrow()); + ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_gglwe_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = 
Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, ct.size(), ct_ggsw.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_gglwe_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as 
f64, + k_ct, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs index d3e6636..08afa1e 100644 --- a/core/src/vec_glwe_product.rs +++ b/core/src/vec_glwe_product.rs @@ -23,7 +23,7 @@ pub(crate) trait VecGLWEProductScratchSpace { Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank) } - fn prod_with_glwe_dft_scratch_space( + fn prod_with_glwe_fourier_scratch_space( module: &Module, res_size: usize, lhs: usize, @@ -32,11 +32,11 @@ pub(crate) trait VecGLWEProductScratchSpace { rank_out: usize, ) -> usize { (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(rank_in, lhs) - + module.bytes_of_vec_znx(rank_out, res_size) + + module.bytes_of_vec_znx(rank_in + 1, lhs) + + module.bytes_of_vec_znx(rank_out + 1, res_size) } - fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + fn prod_with_glwe_fourier_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes()) + module.bytes_of_vec_znx(rank + 1, res_size) } @@ -49,13 +49,13 @@ pub(crate) trait VecGLWEProductScratchSpace { rank_in: usize, rank_out: usize, ) -> usize { - Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) + Self::prod_with_glwe_fourier_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) + module.bytes_of_vec_znx_dft(rank_in + 1, lhs) + module.bytes_of_vec_znx_dft(rank_out + 1, res_size) } fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs, rank) + 
Self::prod_with_glwe_fourier_inplace_scratch_space(module, res_size, rhs, rank) + module.bytes_of_vec_znx_dft(rank + 1, res_size) } } From ccd7450c5fc27610741d14567cd5628b07da9a43 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 15 May 2025 18:24:56 +0200 Subject: [PATCH 70/87] refactor of key-switching & external product --- base2k/src/module.rs | 214 +++--- base2k/src/vec_znx_dft_ops.rs | 41 +- core/src/gglwe_ciphertext.rs | 86 +-- core/src/ggsw_ciphertext.rs | 248 +++---- core/src/glwe_ciphertext.rs | 151 +++- core/src/glwe_ciphertext_fourier.rs | 187 +++-- core/src/glwe_plaintext.rs | 6 +- core/src/keyswitch_key.rs | 240 ++++++- core/src/lib.rs | 1 - core/src/test_fft64/gglwe.rs | 1009 ++++++++++++++------------- core/src/test_fft64/ggsw.rs | 916 +++++++++--------------- core/src/test_fft64/glwe.rs | 6 +- core/src/test_fft64/glwe_fourier.rs | 6 +- core/src/utils.rs | 4 +- core/src/vec_glwe_product.rs | 218 ------ 15 files changed, 1593 insertions(+), 1740 deletions(-) delete mode 100644 core/src/vec_glwe_product.rs diff --git a/base2k/src/module.rs b/base2k/src/module.rs index aab18b4..904d0ec 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,107 +1,107 @@ -use crate::GALOISGENERATOR; -use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; -use std::marker::PhantomData; - -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum BACKEND { - FFT64, - NTT120, -} - -pub trait Backend { - const KIND: BACKEND; - fn module_type() -> u32; -} - -pub struct FFT64; -pub struct NTT120; - -impl Backend for FFT64 { - const KIND: BACKEND = BACKEND::FFT64; - fn module_type() -> u32 { - 0 - } -} - -impl Backend for NTT120 { - const KIND: BACKEND = BACKEND::NTT120; - fn module_type() -> u32 { - 1 - } -} - -pub struct Module { - pub ptr: *mut MODULE, - n: usize, - _marker: PhantomData, -} - -impl Module { - // Instantiates a new module. 
- pub fn new(n: usize) -> Self { - unsafe { - let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); - if m.is_null() { - panic!("Failed to create module."); - } - Self { - ptr: m, - n: n, - _marker: PhantomData, - } - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - pub fn cyclotomic_order(&self) -> u64 { - (self.n() << 1) as _ - } - - // Returns GALOISGENERATOR^|generator| * sign(generator) - pub fn galois_element(&self, generator: i64) -> i64 { - if generator == 0 { - return 1; - } - ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() - } - - // Returns gen^-1 - pub fn galois_element_inv(&self, generator: i64) -> i64 { - if generator == 0 { - panic!("cannot invert 0") - } - ((mod_exp_u64( - generator.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() - } -} - -impl Drop for Module { - fn drop(&mut self) { - unsafe { delete_module_info(self.ptr) } - } -} - -fn mod_exp_u64(x: u64, e: usize) -> u64 { - let mut y: u64 = 1; - let mut x_pow: u64 = x; - let mut exp = e; - while exp > 0 { - if exp & 1 == 1 { - y = y.wrapping_mul(x_pow); - } - x_pow = x_pow.wrapping_mul(x_pow); - exp >>= 1; - } - y -} +use crate::GALOISGENERATOR; +use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum BACKEND { + FFT64, + NTT120, +} + +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; +} + +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + n: usize, + _marker: 
PhantomData, +} + +impl Module { + // Instantiates a new module. + pub fn new(n: usize) -> Self { + unsafe { + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); + if m.is_null() { + panic!("Failed to create module."); + } + Self { + ptr: m, + n: n, + _marker: PhantomData, + } + } + } + + pub fn n(&self) -> usize { + self.n + } + + pub fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + pub fn cyclotomic_order(&self) -> u64 { + (self.n() << 1) as _ + } + + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { + return 1; + } + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() + } + + // Returns gen^-1 + pub fn galois_element_inv(&self, generator: i64) -> i64 { + if generator == 0 { + panic!("cannot invert 0") + } + ((mod_exp_u64( + generator.abs() as u64, + (self.cyclotomic_order() - 1) as usize, + ) & (self.cyclotomic_order() - 1)) as i64) + * generator.signum() + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { delete_module_info(self.ptr) } + } +} + +fn mod_exp_u64(x: u64, e: usize) -> u64 { + let mut y: u64 = 1; + let mut x_pow: u64 = x; + let mut exp = e; + while exp > 0 { + if exp & 1 == 1 { + y = y.wrapping_mul(x_pow); + } + x_pow = x_pow.wrapping_mul(x_pow); + exp >>= 1; + } + y +} diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 282ef4d..27e6f59 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -42,8 +42,13 @@ pub trait VecZnxDftOps { /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + /// b <- IDFT(a), uses a as scratch space. 
- fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize) + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut; @@ -79,13 +84,33 @@ impl VecZnxDftAlloc for Module { } impl VecZnxDftOps for Module { + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); + + (0..min_size).for_each(|j| { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, j)); + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut, { - let mut res_mut = res.to_mut(); - let mut a_mut = a.to_mut(); + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); let min_size: usize = min(res_mut.size(), a_mut.size()); @@ -136,14 +161,14 @@ impl VecZnxDftOps for Module { /// b <- DFT(a) /// /// # Panics - /// If b.cols < a_cols + /// If b.cols < a_col fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxToRef, { - let mut res_mut = res.to_mut(); - let a_ref = a.to_ref(); + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: crate::VecZnx<&[u8]> = a.to_ref(); let min_size: usize = min(res_mut.size(), a_ref.size()); @@ -170,8 +195,8 @@ impl VecZnxDftOps for Module { R: VecZnxBigToMut, A: VecZnxDftToRef, { - let mut res_mut = res.to_mut(); - let a_ref = a.to_ref(); + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes()); diff 
--git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index d20072a..7deb225 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -1,8 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, + VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -13,7 +12,6 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GGLWECiphertext { @@ -212,81 +210,3 @@ where module.vmp_prepare_row(self, row_i, col_j, a); } } - -impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - a_size: usize, - grlwe_size: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size) - + module.bytes_of_vec_znx_dft(rank_in, a_size))) - } -} - -impl VecGLWEProduct for GGLWECiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let basek: usize = self.basek(); - - 
#[cfg(debug_assertions)] - { - assert_eq!(a.rank(), self.rank_in()); - assert_eq!(res.rank(), self.rank_out()); - assert_eq!(res.basek(), basek); - assert_eq!(a.basek(), basek); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - assert!( - scratch.available() - >= GGLWECiphertext::prod_with_glwe_scratch_space( - module, - res.size(), - a.size(), - self.size(), - self.rank_in(), - self.rank_out() - ) - ); - } - - let cols_in: usize = self.rank_in(); - let cols_out: usize = self.rank_out() + 1; - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise - - { - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1); - }); - module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - - (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1); - }); - } -} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index fa3b365..67f4774 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,35 +1,31 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, 
VecZnxDftToMut, + VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ elem::{GetRow, Infos, SetRow}, - gglwe_ciphertext::GGLWECiphertext, glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, - keyswitch_key::GLWESwitchingKey, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GGSWCiphertext { pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, + pub basek: usize, + pub k: usize, } impl GGSWCiphertext, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, + data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)), + basek: basek, + k: k, } } } @@ -42,11 +38,11 @@ impl Infos for GGSWCiphertext { } fn basek(&self) -> usize { - self.log_base2k + self.basek } fn k(&self) -> usize { - self.log_k + self.k } } @@ -82,35 +78,28 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub fn keyswitch_scratch_space( + pub fn external_product_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw } - pub fn 
keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) - } - - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) - } - - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw } } @@ -140,7 +129,7 @@ where } let size: usize = self.size(); - let log_base2k: usize = self.basek(); + let basek: usize = self.basek(); let k: usize = self.k(); let cols: usize = self.rank() + 1; @@ -149,20 +138,20 @@ where let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, - basek: log_base2k, + basek: basek, k: k, }; let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, - basek: log_base2k, + basek: basek, k, }; (0..self.rows()).for_each(|row_j| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); (0..cols).for_each(|col_i| { // rlwe encrypt of vec_znx_pt into vec_znx_ct @@ -193,30 +182,6 @@ where }); } - pub fn keyswitch( - &mut self, - module: 
&Module, - lhs: &GGSWCiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - { - rhs.0.prod_with_vec_glwe(module, self, lhs, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.0.prod_with_vec_glwe_inplace(module, self, scratch); - } - pub fn external_product( &mut self, module: &Module, @@ -227,7 +192,55 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_in rank: {} != ggsw_apply rank: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank() + 1).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn external_product_inplace( @@ -238,7 +251,32 @@ where ) where MatZnxDft: MatZnxDftToRef, { - 
rhs.prod_with_vec_glwe_inplace(module, self, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_out rank: {} != ggsw_apply: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } } @@ -270,73 +308,3 @@ where module.vmp_prepare_row(self, row_i, col_j, a); } } - -impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - a_size: usize, - rgsw_size: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size) - + ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size) - + module.vmp_apply_tmp_bytes( - res_size, - a_size, - a_size, - rank_in + 1, - rank_out + 1, - rgsw_size, - )) - | module.vec_znx_big_normalize_tmp_bytes()) - } -} - -impl VecGLWEProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), a.rank()); - assert_eq!(self.rank(), res.rank()); - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let cols: usize = self.rank() + 1; - - let (mut res_dft, scratch1) = 
scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise - - { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size()); - (0..cols).for_each(|col_i| { - module.vec_znx_dft(&mut a_dft, col_i, a, col_i); - }); - module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1); - }); - } -} diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 82e44da..1875a54 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -1,22 +1,20 @@ use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, + ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ SIX_SIGMA, elem::Infos, - gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GLWECiphertext { @@ -115,33 +113,50 @@ impl GLWECiphertext> { pub fn 
keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, + in_rank + 1, + out_rank + 1, + ksk_size, + ) + module.bytes_of_vec_znx_dft(in_size, in_size))) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank + 1, ggsw_size) + + ((module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + rank + 1, // cols in + rank + 1, // cols out + ggsw_size, + )) + | module.vec_znx_big_normalize_tmp_bytes()) } pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + 
GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank) } } @@ -235,7 +250,50 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0); + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); } pub fn keyswitch_inplace( @@ -246,7 +304,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } } pub fn external_product( @@ -259,7 +320,36 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), 
self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i); + }); + module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); } pub fn external_product_inplace( @@ -270,7 +360,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } } pub(crate) fn encrypt_sk_private( diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index fe2a50d..ebbe9cf 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -1,20 +1,13 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, - VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, + ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, }; use sampling::source::Source; 
use crate::{ - elem::Infos, - gglwe_ciphertext::GGLWECiphertext, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext::GLWECiphertext, - glwe_plaintext::GLWEPlaintext, - keys::SecretKeyFourier, - keyswitch_key::GLWESwitchingKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, + elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size, }; pub struct GLWECiphertextFourier { @@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier { } impl GLWECiphertextFourier, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)), - basek: log_base2k, - k: log_k, + data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)), + basek: basek, + k: k, } } } @@ -92,33 +85,56 @@ impl GLWECiphertextFourier, FFT64> { pub fn keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + + let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, + in_rank + 1, + out_rank + 1, + ksk_size, + ); + let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size); + let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | add_a0 | (res_small + normalize)) } - pub fn 
keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | (res_small + normalize)) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank) } } @@ -158,7 +174,61 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + 
assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertextFourier::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + // Applies VMP + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + // Switches result of VMP outside of DFT + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + { + // Switches lhs 0-th outside of DFT domain and adds on + let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); + module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); + module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); + } + + // Space fr normalized VMP result outside of DFT domain + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size()); + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(self, i, &res_small, i); + }); } pub fn keyswitch_inplace( @@ -169,7 +239,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } } pub fn external_product( @@ -182,7 +255,37 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_fourier(module, 
self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + // Space for VMP result in DFT domain and high precision + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + + { + module.vmp_apply(&mut res_dft, lhs, rhs, scratch1); + } + + // VMP result in high precision + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + // Space for VMP result normalized + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(self, i, &res_small, i); + }); } pub fn external_product_inplace( @@ -193,7 +296,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.external_product(&module, &*self_ptr, rhs, scratch); + } } } @@ -247,6 +353,7 @@ where pt.k = pt.k().min(self.k()); } + #[allow(dead_code)] pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) where GLWECiphertext: VecZnxToMut, diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs index 75088d1..4900fa0 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe_plaintext.rs @@ -43,10 +43,10 @@ where } impl GLWEPlaintext> { - pub fn new(module: &Module, base2k: usize, k: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize) -> Self { Self { - data: module.new_vec_znx(1, derive_size(base2k, k)), - basek: base2k, + data: module.new_vec_znx(1, 
derive_size(basek, k)), + basek: basek, k, } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 37774eb..8b9f13d 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, }; use sampling::source::Source; @@ -10,7 +10,6 @@ use crate::{ ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, keys::{SecretKey, SecretKeyFourier}, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); @@ -39,6 +38,20 @@ impl Infos for GLWESwitchingKey { } } +impl GLWESwitchingKey { + pub fn rank(&self) -> usize { + self.0.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.0.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.0.data.cols_out() - 1 + } +} + impl MatZnxDftToMut for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, @@ -131,33 +144,46 @@ where impl GLWESwitchingKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size); + tmp_in + tmp_out + ksk } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: 
usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size); + tmp + ksk } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw } } @@ -175,8 +201,62 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0 - .prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != 
ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn keyswitch_inplace( @@ -187,8 +267,32 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0 - .prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + 
(0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.keyswitch_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } pub fn external_product( @@ -201,7 +305,62 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank(), + "ksk_in output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn external_product_inplace( @@ -212,6 +371,31 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + 
self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 14392df..60d57c2 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -9,6 +9,5 @@ pub mod keyswitch_key; #[cfg(test)] mod test_fft64; mod utils; -pub mod vec_glwe_product; pub(crate) const SIX_SIGMA: f64 = 6.0; diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 8327325..3ba02a0 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -1,505 +1,510 @@ -// use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -// use sampling::source::Source; -// -// use crate::{ -// elem::{GetRow, Infos}, -// ggsw_ciphertext::GGSWCiphertext, -// glwe_ciphertext_fourier::GLWECiphertextFourier, -// glwe_plaintext::GLWEPlaintext, -// keys::{SecretKey, SecretKeyFourier}, -// keyswitch_key::GLWESwitchingKey, -// test_fft64::ggsw::noise_rgsw_product, -// }; -// -// #[test] -// fn encrypt_sk() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 8; -// let log_k_ct: usize = 54; -// let rows: usize = 4; -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut 
pt_scalar: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// sk.fill_zero(); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct.encrypt_sk( -// &module, -// &pt_scalar, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); -// -// (0..ct.rows()).for_each(|row_i| { -// ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); -// let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); -// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); -// }); -// } -// -// #[test] -// fn keyswitch() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_grlwe_s0s2: 
GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) -// | GLWESwitchingKey::keyswitch_scratch_space( -// &module, -// ct_grlwe_s0s2.size(), -// ct_grlwe_s0s1.size(), -// ct_grlwe_s1s2.size(), -// ), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// let mut sk2: SecretKey> = SecretKey::new(&module, rank); -// sk2.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk2_dft.dft(&module, &sk2); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_s0s1.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s2}(s1) -> s1 -> s2 -// ct_grlwe_s1s2.encrypt_sk( -// &module, -// &sk1.data, -// &sk2_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) -// ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, 
log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { -// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } -// -// #[test] -// fn keyswitch_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) -// | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: 
SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// let mut sk2: SecretKey> = SecretKey::new(&module, rank); -// sk2.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk2_dft.dft(&module, &sk2); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_s0s1.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s2}(s1) -> s1 -> s2 -// ct_grlwe_s1s2.encrypt_sk( -// &module, -// &sk1.data, -// &sk2_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) -// ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); -// -// let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { -// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want 
-// ); -// }); -// } -// -// #[test] -// fn external_product() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) -// | GLWESwitchingKey::external_product_scratch_space( -// &module, -// ct_grlwe_out.size(), -// ct_grlwe_in.size(), -// ct_rgsw.size(), -// ) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), -// ); -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_in.encrypt_sk( -// &module, -// &pt_grlwe, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// 
scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) -// ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); -// -// (0..ct_grlwe_out.rows()).for_each(|row_i| { -// ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } -// -// #[test] -// fn external_product_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_rgsw: 
GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) -// | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), -// ); -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe.encrypt_sk( -// &module, -// &pt_grlwe, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) -// ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); -// -// (0..ct_grlwe.rows()).for_each(|row_i| { -// ct_grlwe.get_row(&module, row_i, 0, &mut 
ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out); + test_encrypt_sk(11, 8, 54, 3.2, rank_in, rank_out); + }); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: 
ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ksk.size()), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_gglwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out); + + (0..ksk.rank_in()).for_each(|col_i| { + (0..ksk.rows()).for_each(|row_i| { + ksk.get_row(&module, row_i, 0, &mut ct_gglwe_fourier); + ct_gglwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); +} + +#[test] +fn keyswitch() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + + let mut source_xs: Source = 
Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GLWESwitchingKey::keyswitch_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, 
&mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn keyswitch_inplace() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank); + sk2.fill_ternary_prob(0.5, &mut 
source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + + let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn external_product() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: 
GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GLWESwitchingKey::external_product_scratch_space( + &module, + ct_grlwe_out.size(), + ct_grlwe_in.size(), + ct_rgsw.size(), + ) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_in.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe_out.rows()).for_each(|row_i| { + ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, 
scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn external_product_inplace() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = 
SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe.rows()).for_each(|row_i| { + ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} pub(crate) fn noise_gglwe_product( n: f64, - log_base2k: usize, + basek: usize, var_xs: f64, var_msg: f64, var_a_err: f64, @@ -510,12 +515,12 @@ pub(crate) fn noise_gglwe_product( b_logq: usize, ) -> f64 { let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / 
log_base2k; + let a_cols: usize = (a_logq + basek - 1) / basek; let b_scale = 2.0f64.powi(b_logq as i32); let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - let base: f64 = (1 << (log_base2k)) as f64; + let base: f64 = (1 << (basek)) as f64; let var_base: f64 = base * base / 12f64; // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index eb8c532..cf34dda 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,575 +1,345 @@ -// use base2k::{ -// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, -// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, -// }; -// use sampling::source::Source; -// -// use crate::{ -// elem::{GetRow, Infos}, -// ggsw_ciphertext::GGSWCiphertext, -// glwe_ciphertext_fourier::GLWECiphertextFourier, -// glwe_plaintext::GLWEPlaintext, -// keys::{SecretKey, SecretKeyFourier}, -// keyswitch_key::GLWESwitchingKey, -// test_fft64::gglwe::noise_grlwe_rlwe_product, -// }; -// -// #[test] -// fn encrypt_sk() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 8; -// let log_k_ct: usize = 54; -// let rows: usize = 4; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); -// -// let mut scratch: ScratchOwned = 
ScratchOwned::new( -// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct.encrypt_sk( -// &module, -// &pt_scalar, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); -// -// (0..ct.rank()).for_each(|col_j| { -// (0..ct.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// -// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); -// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn keyswitch() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rgsw_in: usize = 45; -// let log_k_rgsw_out: usize = 45; -// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; -// -// let 
rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); -// let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) -// | GGSWCiphertext::keyswitch_scratch_space( -// &module, -// ct_rgsw_out.size(), -// ct_rgsw_in.size(), -// ct_grlwe.size(), -// ), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_in.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_out.keyswitch(&module, 
&ct_rgsw_in, &ct_grlwe, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); -// -// (0..ct_rgsw_out.rank()).for_each(|col_j| { -// (0..ct_rgsw_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.2, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn keyswitch_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rgsw: usize = 45; -// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, 
FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) -// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); -// let mut pt_dft: 
VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); -// -// (0..ct_rgsw.rank()).for_each(|col_j| { -// (0..ct_rgsw.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.2, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn external_product() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_rgsw_rhs: usize = 60; -// let log_k_rgsw_lhs_in: usize = 45; -// let log_k_rgsw_lhs_out: usize = 45; -// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); -// let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = -// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); -// let mut ct_rgsw_lhs_out: GGSWCiphertext, 
FFT64> = -// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); -// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let k: usize = 1; -// -// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) -// | GGSWCiphertext::external_product_scratch_space( -// &module, -// ct_rgsw_lhs_out.size(), -// ct_rgsw_lhs_in.size(), -// ct_rgsw_rhs.size(), -// ), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw_rhs.encrypt_sk( -// &module, -// &pt_rgsw_rhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs_in.encrypt_sk( -// &module, -// &pt_rgsw_lhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); 
-// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); -// -// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { -// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rgsw_lhs_in, -// log_k_rgsw_rhs, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn external_product_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_rgsw_rhs: usize = 60; -// let log_k_rgsw_lhs: usize = 45; -// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut 
ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); -// let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); -// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let k: usize = 1; -// -// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) -// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw_rhs.encrypt_sk( -// &module, -// &pt_rgsw_rhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs.encrypt_sk( -// &module, -// &pt_rgsw_lhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); -// let mut pt_dft: 
VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); -// -// (0..ct_rgsw_lhs.rank()).for_each(|col_j| { -// (0..ct_rgsw_lhs.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rgsw_lhs, -// log_k_rgsw_rhs, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -pub(crate) fn noise_ggsw_gglwe_product( +use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +}; +use 
sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(11, 8, 54, 3.2, rank); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, rank, 3.2); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, rank, 3.2); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut 
source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + 
+ pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) + | GGSWCiphertext::external_product_scratch_space( + &module, + ct_ggsw_lhs_out.size(), + ct_ggsw_lhs_in.size(), + ct_ggsw_rhs.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + &pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_in.encrypt_sk( + &module, + &pt_ggsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { + (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); + + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 
col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ggsw, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) + | 
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size()) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + &pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs.encrypt_sk( + &module, + &pt_ggsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| { + (0..ct_ggsw_lhs.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); + + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut 
pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ggsw, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} +pub(crate) fn noise_ggsw_product( n: f64, - log_base2k: usize, + basek: usize, var_xs: f64, var_msg: f64, var_a0_err: f64, @@ -581,12 +351,12 @@ pub(crate) fn noise_ggsw_gglwe_product( b_logq: usize, ) -> f64 { let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + let a_cols: usize = (a_logq + basek - 1) / basek; let b_scale = 2.0f64.powi(b_logq as i32); let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - let base: f64 = (1 << (log_base2k)) as f64; + let base: f64 = (1 << (basek)) as f64; let var_base: f64 = base * base / 12f64; // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 21bae6d..2d83791 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -13,7 +13,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, }; #[test] @@ -498,7 +498,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = 
noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, @@ -595,7 +595,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index d5ed622..c737c55 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -6,7 +6,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; @@ -322,7 +322,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, @@ -422,7 +422,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, diff --git a/core/src/utils.rs b/core/src/utils.rs index 0bb0b45..c3bc5d5 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -1,3 +1,3 @@ -pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { - (log_k + log_base2k - 1) / log_base2k +pub(crate) fn derive_size(basek: usize, k: usize) -> usize { + (k + basek - 1) / basek } diff --git 
a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs deleted file mode 100644 index 08afa1e..0000000 --- a/core/src/vec_glwe_product.rs +++ /dev/null @@ -1,218 +0,0 @@ -use base2k::{ - FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, -}; - -use crate::{ - elem::{GetRow, Infos, SetRow}, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, -}; - -pub(crate) trait VecGLWEProductScratchSpace { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize; - - fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank) - } - - fn prod_with_glwe_fourier_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(rank_in + 1, lhs) - + module.bytes_of_vec_znx(rank_out + 1, res_size) - } - - fn prod_with_glwe_fourier_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(rank + 1, res_size) - } - - fn prod_with_vec_glwe_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - Self::prod_with_glwe_fourier_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) - + module.bytes_of_vec_znx_dft(rank_in + 1, lhs) - + module.bytes_of_vec_znx_dft(rank_out + 1, res_size) - } - - fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, 
res_size: usize, rhs: usize, rank: usize) -> usize { - Self::prod_with_glwe_fourier_inplace_scratch_space(module, res_size, rhs, rank) - + module.bytes_of_vec_znx_dft(rank + 1, res_size) - } -} - -pub(crate) trait VecGLWEProduct: Infos { - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; - - fn prod_with_glwe_inplace(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut GLWECiphertext = res as *mut GLWECiphertext; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - fn prod_with_glwe_fourier( - &self, - module: &Module, - res: &mut GLWECiphertextFourier, - a: &GLWECiphertextFourier, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size()); - - let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: a_data, - basek: a.basek(), - k: a.k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size()); - - let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: res.basek(), - k: res.k(), - }; - - self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); - - res_idft.dft(module, res); - } - - fn prod_with_glwe_fourier_inplace( - &self, - module: &Module, - res: &mut GLWECiphertextFourier, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToRef + 
VecZnxDftToMut, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size()); - - let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: res.basek(), - k: res.k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); - - res_idft.dft(module, res); - } - - fn prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) - where - LHS: GetRow + Infos, - RES: SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size()); - - let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_row_data, - basek: a.basek(), - k: a.k(), - }; - - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size()); - - let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_res_data, - basek: res.basek(), - k: res.k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_a_row); - self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - - tmp_res_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - } - - fn prod_with_vec_glwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) - where - RES: GetRow + SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size()); - - let 
mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_row_data, - basek: res.basek(), - k: res.k(), - }; - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - res.get_row(module, row_i, col_j, &mut tmp_row); - self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); - }); - }); - } -} From 49a08289db6441bd6be792aa49c942ad750520c8 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 09:47:04 +0200 Subject: [PATCH 71/87] base2k: fixed buffer zeroing overflow --- base2k/src/znx_base.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index a168e18..c8636f3 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -113,7 +113,7 @@ where fn zero_at(&mut self, i: usize, j: usize) { unsafe { - std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.sl()); + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); } } } From b80bcb8bbd2a58ab1f030be357a05d3100dfc18f Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 09:56:39 +0200 Subject: [PATCH 72/87] fixed another buffer overflow of coefficient zeroing --- base2k/src/znx_base.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index c8636f3..f618446 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -107,7 +107,7 @@ where { fn zero(&mut self) { unsafe { - std::ptr::write_bytes(self.as_mut_ptr(), 0, self.sl() * self.poly_count()); + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); } } From c86af112eba8c6e4e72623bb59966dec1149c528 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 10:22:42 +0200 Subject: [PATCH 73/87] All test passing --- core/src/glwe_ciphertext.rs | 46 +-- core/src/keyswitch_key.rs | 8 - core/src/test_fft64/gglwe.rs 
| 562 +++++++++++++++------------- core/src/test_fft64/glwe.rs | 8 +- core/src/test_fft64/glwe_fourier.rs | 6 +- 5 files changed, 341 insertions(+), 289 deletions(-) diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 1875a54..f0dbae1 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -119,16 +119,18 @@ impl GLWECiphertext> { in_rank: usize, ksk_size: usize, ) -> usize { - module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, - in_rank + 1, - out_rank + 1, - ksk_size, - ) + module.bytes_of_vec_znx_dft(in_size, in_size))) + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); + let vmp: usize = module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, + in_rank + 1, + out_rank + 1, + ksk_size, + ) + module.bytes_of_vec_znx_dft(in_rank, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + return res_dft + (vmp | normalize); } pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { @@ -142,17 +144,19 @@ impl GLWECiphertext> { ggsw_size: usize, rank: usize, ) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, ggsw_size) - + ((module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols out - ggsw_size, - )) - | module.vec_znx_big_normalize_tmp_bytes()) + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + rank + 1, // cols in + rank + 1, // cols out + ggsw_size, + ); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | normalize) } pub fn external_product_inplace_scratch_space(module: &Module, res_size: 
usize, rhs: usize, rank: usize) -> usize { diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 8b9f13d..34595e3 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -78,10 +78,6 @@ where where VecZnxDft: VecZnxDftToMut, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_extract_row(res, self, row_i, col_j); } } @@ -94,10 +90,6 @@ where where VecZnxDft: VecZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_prepare_row(self, row_i, col_j, a); } } diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 3ba02a0..ff4bcfe 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -1,4 +1,4 @@ -use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; use crate::{ @@ -8,6 +8,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, + test_fft64::ggsw::noise_ggsw_product, }; #[test] @@ -15,7 +16,58 @@ fn encrypt_sk() { (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out); - test_encrypt_sk(11, 8, 54, 3.2, rank_in, rank_out); + test_encrypt_sk(12, 8, 54, 3.2, rank_in, rank_out); + }); + }); +} + +#[test] +fn key_switch() { + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + (1..4).for_each(|rank_out_s1s2| { + println!( + "test key_switch : ({},{},{})", + rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 + ); + test_key_switch(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2); + }) + }); + }); +} + +#[test] +fn key_switch_inplace() { + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + println!( + "test key_switch_inplace : ({},{})", + rank_in_s0s1, rank_out_s0s1 + ); + 
test_key_switch_inplace(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1); + }); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test external_product rank: {} {}", rank_in, rank_out); + test_external_product(12, 12, 60, 3.2, rank_in, rank_out); + }); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!( + "test external_product_inplace rank: {} {}", + rank_in, rank_out + ); + test_external_product_inplace(12, 12, 60, 3.2, rank_in, rank_out); }); }); } @@ -58,12 +110,12 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in scratch.borrow(), ); - let mut ct_gglwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out); (0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rows()).for_each(|row_i| { - ksk.get_row(&module, row_i, 0, &mut ct_gglwe_fourier); - ct_gglwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); @@ -71,61 +123,64 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in }); } -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + basek - 1) / basek; +fn test_key_switch( + log_n: usize, + basek: usize, + k_ksk: usize, + sigma: f64, + rank_in_s0s1: usize, + rank_out_s0s1: usize, + rank_out_s1s2: usize, +) { + let module: Module = Module::::new(1 << log_n); 
+ let rows = (k_ksk + basek - 1) / basek; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s1s2); + let mut ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s1s2); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in_s0s1 | rank_out_s0s1, ct_gglwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s2.size()) | GLWESwitchingKey::keyswitch_scratch_space( &module, - ct_grlwe_s0s2.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), + ct_gglwe_s0s2.size(), + ct_gglwe_s0s2.rank(), + ct_gglwe_s0s1.size(), + ct_gglwe_s0s1.rank(), + ct_gglwe_s1s2.size(), ), ); - let mut sk0: SecretKey> = SecretKey::new(&module, rank); + let mut sk0: SecretKey> = SecretKey::new(&module, rank_in_s0s1); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk0_dft: SecretKeyFourier, FFT64> = 
SecretKeyFourier::new(&module, rank_in_s0s1); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module, rank); + let mut sk1: SecretKey> = SecretKey::new(&module, rank_out_s0s1); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module, rank); + let mut sk2: SecretKey> = SecretKey::new(&module, rank_out_s1s2); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s1s2); sk2_dft.dft(&module, &sk2); - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( &module, - &sk0.data, + &sk0, &sk1_dft, &mut source_xa, &mut source_xe, @@ -133,10 +188,10 @@ fn keyswitch() { scratch.borrow(), ); - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( &module, - &sk1.data, + &sk1, &sk2_dft, &mut source_xa, &mut source_xe, @@ -144,89 +199,88 @@ fn keyswitch() { scratch.borrow(), ); - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s1s2); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, 
basek, k_ksk); - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { + (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { + ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); - let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - basek, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_out_s0s1 as f64, + k_ksk, + k_ksk, + ); - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); }); } -#[test] -fn keyswitch_inplace() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + basek - 1) / basek; +fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in_s0s1: usize, rank_out_s0s1: usize) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ksk + basek - 1) / basek; - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_gglwe_s0s1: 
GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s0s1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out_s0s1, ct_gglwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s1.size()) + | GLWESwitchingKey::keyswitch_inplace_scratch_space( + &module, + ct_gglwe_s0s1.size(), + ct_gglwe_s0s1.rank(), + ct_gglwe_s1s2.size(), + ), ); - let mut sk0: SecretKey> = SecretKey::new(&module, rank); + let mut sk0: SecretKey> = SecretKey::new(&module, rank_in_s0s1); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module, rank); + let mut sk1: SecretKey> = SecretKey::new(&module, rank_out_s0s1); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module, rank); + let mut sk2: SecretKey> = SecretKey::new(&module, rank_out_s0s1); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, 
FFT64> = SecretKeyFourier::new(&module, rank); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); sk2_dft.dft(&module, &sk2); - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( &module, - &sk0.data, + &sk0, &sk1_dft, &mut source_xa, &mut source_xe, @@ -234,10 +288,10 @@ fn keyswitch_inplace() { scratch.borrow(), ); - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( &module, - &sk1.data, + &sk1, &sk2_dft, &mut source_xa, &mut source_xe, @@ -245,96 +299,93 @@ fn keyswitch_inplace() { scratch.borrow(), ); - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s1.keyswitch_inplace(&module, &ct_gglwe_s1s2, scratch.borrow()); - let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; + let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s0s1); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { + (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { + ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + 
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); - let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - basek, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_out_s0s1 as f64, + k_ksk, + k_ksk, + ); - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); }); } -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + basek - 1) / basek; +fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); - let rank: usize = 1; - let rank_out: usize = 1; + let rows: usize = (k + basek - 1) / basek; - let sigma: f64 = 3.2; - - let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_gglwe_in: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_gglwe_out: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); let mut 
source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size()) | GLWESwitchingKey::external_product_scratch_space( &module, - ct_grlwe_out.size(), - ct_grlwe_in.size(), + ct_gglwe_out.size(), + ct_gglwe_in.size(), ct_rgsw.size(), + rank_out, ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), ); - let k: usize = 1; + let r: usize = 1; - pt_rgsw.raw_mut()[k] = 1; // X^{k} + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_in.encrypt_sk( + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_in.encrypt_sk( &module, - &pt_grlwe, - &sk_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -344,104 +395,104 @@ fn external_product() { 
ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_dft, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); - (0..ct_grlwe_out.rows()).for_each(|row_i| { - ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + (0..rank_in).for_each(|col_i| { + (0..ct_gglwe_out.rows()).for_each(|row_i| { + ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); - let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_have: f64 = pt.data.std(0, basek).log2(); - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let 
var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - basek, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k, + k, + ); - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); }); } -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + basek - 1) / basek; +fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); - let rank: usize = 1; - let rank_out: usize = 1; + let rows: usize = (k + basek - 1) / basek; - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_gglwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | 
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) - | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe.size()) + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_gglwe.size(), ct_rgsw.size(), rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), ); - let k: usize = 1; + let r: usize = 1; - pt_rgsw.raw_mut()[k] = 1; // X^{k} + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe.encrypt_sk( + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe.encrypt_sk( &module, - &pt_grlwe, - &sk_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -451,57 +502,62 @@ fn external_product_inplace() { ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_dft, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); + // gglwe_(m) (x) 
RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); - (0..ct_grlwe.rows()).for_each(|row_i| { - ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + (0..rank_in).for_each(|col_i| { + (0..ct_gglwe.rows()).for_each(|row_i| { + ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); - let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_have: f64 = pt.data.std(0, basek).log2(); - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - basek, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + 
var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k, + k, + ); - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); }); } + pub(crate) fn noise_gglwe_product( n: f64, basek: usize, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 2d83791..37bfc4e 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -263,10 +263,10 @@ fn test_keyswitch( | GLWECiphertext::keyswitch_scratch_space( &module, ct_out.size(), - ct_in.size(), - ksk.size(), - rank_in, rank_out, + ct_in.size(), + rank_in, + ksk.size(), ), ); @@ -352,7 +352,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size(), rank), + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), rank, ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module, rank); diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index c737c55..3887558 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -84,10 +84,10 @@ fn test_keyswitch( | GLWECiphertextFourier::keyswitch_scratch_space( &module, ct_glwe_out.size(), - ct_glwe_in.size(), - ksk.size(), - rank_in, rank_out, + ct_glwe_in.size(), + rank_in, + ksk.size(), ), ); From 7434f289fe8a8c9156b1529e9eb6e52888d207f7 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 14:15:41 +0200 Subject: [PATCH 74/87] Added automorphism for glwe --- base2k/src/scalar_znx.rs | 64 +++++++++ 
core/src/automorphism.rs | 253 +++++++++++++++++++++++++++++++++++ core/src/gglwe_ciphertext.rs | 9 +- core/src/glwe_ciphertext.rs | 71 ++++++++-- core/src/keyswitch_key.rs | 62 ++++----- core/src/lib.rs | 1 + core/src/test_fft64/glwe.rs | 109 +++++++++++++++ 7 files changed, 521 insertions(+), 48 deletions(-) create mode 100644 core/src/automorphism.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 108ba3f..8da145f 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,3 +1,4 @@ +use crate::ffi::vec_znx; use crate::znx_base::ZnxInfos; use crate::{ Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, @@ -122,6 +123,69 @@ impl ScalarZnxAlloc for Module { } } +pub trait ScalarZnxOps { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef; + + /// Applies the automorphism X^i -> X^ik on the selected column of `a`. 
+ fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut; +} + +impl ScalarZnxOps for Module { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: ScalarZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut, + { + let mut a: ScalarZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + impl ScalarZnx { pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols } diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs new file mode 100644 index 0000000..ed6a954 --- /dev/null +++ b/core/src/automorphism.rs @@ -0,0 +1,253 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, + ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct AutomorphismKey { + pub(crate) key: GLWESwitchingKey, + pub(crate) p: i64, +} + +impl AutomorphismKey, FFT64> { 
+ pub fn new(module: &Module, basek: usize, p: i64, k: usize, rows: usize, rank: usize) -> Self { + AutomorphismKey { + key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank), + p: p, + } + } +} + +impl Infos for AutomorphismKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl AutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.key.rank_out() + } +} + +impl MatZnxDftToMut for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.key.to_mut() + } +} + +impl MatZnxDftToRef for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.key.to_ref() + } +} + +impl GetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl AutomorphismKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + 
rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank) + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + p: i64, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank_out(), self.rank_in()); + assert_eq!(sk.rank(), self.rank()); + } + + let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); + + let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier { + data: sk_out_dft_data, + dist: sk.dist, + }; + + { + (0..self.rank()).for_each(|i| { + let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1); + module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i); + module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0); + }); + } + + self.key.encrypt_sk( + module, + &sk, + &sk_out_dft, + source_xa, + source_xe, + sigma, + scratch_1, + ); + + self.p = p; + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + 
MatZnxDftToRef, +{ + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch(module, &lhs.key, rhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch_inplace(module, &rhs.key, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product(module, &lhs.key, rhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product_inplace(module, rhs, scratch); + } +} diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 7deb225..863fd54 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -21,10 +21,10 @@ pub struct GGLWECiphertext { } impl GGLWECiphertext, B> { - pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(base2k, k)), - basek: base2k, + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)), + basek: basek, k, } } @@ -161,6 +161,7 @@ where (0..cols_in).for_each(|col_i| { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + vec_znx_pt.data.zero(); // zeroes for next iteration module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th 
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); @@ -175,8 +176,6 @@ where scratch_3, ); - vec_znx_pt.data.zero(); // zeroes for next iteration - // Switch vec_znx_ct into DFT domain vec_znx_ct.dft(module, &mut vec_znx_ct_dft); diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index f0dbae1..422f2cc 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -8,6 +8,7 @@ use sampling::source::Source; use crate::{ SIX_SIGMA, + automorphism::AutomorphismKey, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -137,21 +138,40 @@ impl GLWECiphertext> { GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, in_size, out_rank, autokey_size) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, autokey_size) + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, + out_rank: usize, in_size: usize, ggsw_size: usize, - rank: usize, ) -> usize { - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ggsw_size); + let vmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, in_size) + module.vmp_apply_tmp_bytes( out_size, in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols out + in_size, // rows + out_rank + 1, // cols in + out_rank + 1, // cols out ggsw_size, ); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); @@ -159,8 +179,13 @@ 
impl GLWECiphertext> { res_dft + (vmp | normalize) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + ggsw_size: usize, + ) -> usize { + GLWECiphertext::external_product_scratch_space(module, out_size, out_rank, out_size, ggsw_size) } } @@ -244,6 +269,36 @@ where self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); } + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch(module, lhs, &rhs.key, scratch); + //(0..self.rank() + 1).for_each(|i| { + // module.vec_znx_automorphism_inplace(rhs.p(), self, i); + //}) + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch_inplace(module, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), self, i); + }) + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 34595e3..e01df09 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -15,9 +15,9 @@ use crate::{ pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); impl GLWESwitchingKey, FFT64> { - pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { GLWESwitchingKey(GGLWECiphertext::new( - module, base2k, k, rows, rank_in, rank_out, + module, basek, k, rows, rank_in, rank_out, )) } } @@ -26,7 +26,7 @@ impl Infos 
for GLWESwitchingKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { - &self.0.inner() + self.0.inner() } fn basek(&self) -> usize { @@ -102,38 +102,7 @@ impl GLWESwitchingKey, FFT64> { pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) } -} -impl GLWESwitchingKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, -{ - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_in: &SecretKey, - sk_out_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - ScalarZnx: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - self.0.encrypt_sk( - module, - &sk_in.data, - sk_out_dft, - source_xa, - source_xe, - sigma, - scratch, - ); - } -} - -impl GLWESwitchingKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, out_size: usize, @@ -178,11 +147,34 @@ impl GLWESwitchingKey, FFT64> { tmp + ggsw } } - impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_in: &SecretKey, + sk_out_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out_dft, + source_xa, + source_xe, + sigma, + scratch, + ); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/lib.rs b/core/src/lib.rs index 60d57c2..f04ca06 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,3 +1,4 @@ +pub mod automorphism; pub mod elem; pub mod gglwe_ciphertext; pub mod ggsw_ciphertext; diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 37bfc4e..525de22 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -6,6 +6,7 @@ use itertools::izip; use sampling::source::Source; use crate::{ + 
automorphism::AutomorphismKey, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, @@ -415,6 +416,114 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); } +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(12, 12, 1, 60, 45, 60, rank, 3.2); + }); +} + +fn test_automorphism( + log_n: usize, + basek: usize, + p: i64, + k_autokey: usize, + k_ct_in: usize, + k_ct_out: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, p, k_autokey, rows, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + // pt_want + // .data + // .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want + .to_mut() + .at_mut(0, 1) + .iter_mut() + .enumerate() + .for_each(|(i, x)| { + *x = i as i64; + }); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertext::automorphism_scratch_space(&module, ct_out.size(), rank, ct_in.size(), autokey.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: 
SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct_in, + k_autokey, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); From b71e526260ef2403032634e3660bad08a712d8a6 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 16:27:49 +0200 Subject: [PATCH 75/87] wip adding automorphism on AutomorphismKey --- base2k/src/scalar_znx.rs | 4 +- base2k/src/vec_znx_big.rs | 15 ++++- base2k/src/vec_znx_big_ops.rs | 2 + core/src/automorphism.rs | 100 +++++++++++++++++++++++++++- core/src/glwe_ciphertext.rs | 64 ++++++++++++++++++ core/src/glwe_ciphertext_fourier.rs | 61 +++-------------- 6 files changed, 191 insertions(+), 55 deletions(-) diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 8da145f..fa812a8 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -130,7 +130,7 @@ pub trait ScalarZnxOps { A: ScalarZnxToRef; /// Applies the automorphism X^i -> X^ik on the selected column of `a`. 
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: ScalarZnxToMut; } @@ -162,7 +162,7 @@ impl ScalarZnxOps for Module { } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: ScalarZnxToMut, { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8b3223b..eba90e9 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64}; use std::fmt; use std::marker::PhantomData; @@ -97,7 +97,18 @@ impl VecZnxBig { impl VecZnxBig where VecZnxBig: VecZnxBigToMut + ZnxInfos, -{ +{ + // Consumes the VecZnxBig to return a VecZnx. + // Useful when no normalization is needed. + pub fn to_vec_znx_small(self) -> VecZnx{ + VecZnx{ + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) where diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 8208c97..f6dad7a 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -147,6 +147,7 @@ pub trait VecZnxBigOps { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + } pub trait VecZnxBigScratch { @@ -169,6 +170,7 @@ impl VecZnxBigAlloc for Module { } impl VecZnxBigOps for Module { + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxBigToMut, diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index ed6a954..8741bf9 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, }; use sampling::source::Source; @@ -8,6 +8,7 @@ use crate::{ elem::{GetRow, Infos, SetRow}, gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, @@ -203,6 +204,103 @@ impl AutomorphismKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { + pub fn automorphism( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &AutomorphismKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output 
rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let cols_out: usize = rhs.rank_out() + 1; + + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size()); + + let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + + // Consumes to small vec znx + let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); + + // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i); + }); + + // Wraps into ciphertext + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_small_data, + basek: self.basek(), + k: self.k(), + }; + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, 
col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 422f2cc..245ee26 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -299,6 +299,70 @@ where }) } + pub(crate) fn keyswitch_from_fourier( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertextFourier::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + // Applies VMP + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + // Switches result of VMP outside of DFT + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + { + // Switches lhs 0-th outside of DFT domain and adds on + let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, 
lhs.size()); + module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); + module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); + } + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index ebbe9cf..4c22507 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -174,60 +174,21 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertextFourier::keyswitch_scratch_space( - module, - self.size(), - self.rank(), - lhs.size(), - lhs.rank(), - rhs.size(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); let cols_out: usize = rhs.rank_out() + 1; - // Buffer of the result of VMP in DFT - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - - { - // Applies VMP - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); - }); - module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); - } - - // Switches result of VMP outside of DFT - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - { - // Switches lhs 0-th outside of DFT domain and adds on - let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); - module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); - module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); 
- } - // Space fr normalized VMP result outside of DFT domain - let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size()); + let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_idft_data, + basek: self.basek, + k: self.k, + }; + + res_idft.keyswitch_from_fourier(module, self, rhs, scratch1); + (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(self, i, &res_small, i); + module.vec_znx_dft(self, i, &res_idft, i); }); } From 937e7c6ccf2e244cf21edec591d802a572a6cb6f Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 13:04:37 +0200 Subject: [PATCH 76/87] fixed all broken tests --- core/src/ggsw_ciphertext.rs | 4 ++-- core/src/glwe_ciphertext.rs | 32 +++++++++++++++++++++-------- core/src/glwe_ciphertext_fourier.rs | 24 +++++----------------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 67f4774..577bd6e 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -149,6 +149,8 @@ where }; (0..self.rows()).for_each(|row_j| { + vec_znx_pt.data.zero(); + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); @@ -177,8 +179,6 @@ where module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); } }); - - vec_znx_pt.data.zero(); // zeroes for next iteration }); } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 245ee26..0fd6242 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -121,19 +121,33 @@ impl GLWECiphertext> { ksk_size: usize, ) -> usize { let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); - let vmp: 
usize = module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, - in_rank + 1, - out_rank + 1, - ksk_size, - ) + module.bytes_of_vec_znx_dft(in_rank, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); return res_dft + (vmp | normalize); } + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + let res_dft = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); + + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); + + let a0_big: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); + + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | a0_big | norm) + } + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } @@ -322,7 +336,7 @@ where assert_eq!(lhs.n(), module.n()); assert!( scratch.available() - >= GLWECiphertextFourier::keyswitch_scratch_space( + >= GLWECiphertext::keyswitch_from_fourier_scratch_space( module, self.size(), self.rank(), diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index 4c22507..135a2dd 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -91,22 +91,8 @@ impl GLWECiphertextFourier, FFT64> { in_rank: usize, ksk_size: usize, ) -> usize { - let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); - - let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size) - + module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, - 
in_rank + 1, - out_rank + 1, - ksk_size, - ); - let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size); - let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - - res_dft + (vmp | add_a0 | (res_small + normalize)) + module.bytes_of_vec_znx(out_rank + 1, out_size) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size) } pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { @@ -181,11 +167,11 @@ where let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_idft_data, - basek: self.basek, - k: self.k, + basek: lhs.basek, + k: lhs.k, }; - res_idft.keyswitch_from_fourier(module, self, rhs, scratch1); + res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1); (0..cols_out).for_each(|i| { module.vec_znx_dft(self, i, &res_idft, i); From 13e26c815282780630426be3aff8c0065814813e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 13:17:56 +0200 Subject: [PATCH 77/87] Added test for automorphism inplace --- core/src/test_fft64/glwe.rs | 110 ++++++++++++++++++++++++++++++++---- 1 file changed, 99 insertions(+), 11 deletions(-) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 525de22..53c06fe 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -448,18 +448,9 @@ fn test_automorphism( let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - // pt_want - // .data - // .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - pt_want - .to_mut() - .at_mut(0, 1) - .iter_mut() - .enumerate() - .for_each(|(i, x)| { - *x = i as i64; - }); + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, 
rank, autokey.size()) @@ -524,6 +515,103 @@ fn test_automorphism( ); } +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(12, 12, 1, 60, 60, rank, 3.2); + }); +} + +fn test_automorphism_inplace( + log_n: usize, + basek: usize, + p: i64, + k_autokey: usize, + k_ct: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, p, k_autokey, rows, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::automorphism_inplace_scratch_space(&module, ct.size(), rank, autokey.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &autokey, 
scratch.borrow()); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_autokey, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); From b9cc21079362867684a341ffdc2bd581c32d9cd3 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 14:05:20 +0200 Subject: [PATCH 78/87] added test for automorphism on automorphism key --- base2k/src/module.rs | 8 +- base2k/src/vec_znx_ops.rs | 1 + core/src/automorphism.rs | 27 +++++- core/src/test_fft64/automorphism_key.rs | 117 ++++++++++++++++++++++++ core/src/test_fft64/gglwe.rs | 13 +++ core/src/test_fft64/glwe.rs | 14 +-- core/src/test_fft64/mod.rs | 1 + 7 files changed, 161 insertions(+), 20 deletions(-) create mode 100644 core/src/test_fft64/automorphism_key.rs diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 904d0ec..8ee6e4b 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -74,15 +74,15 @@ impl Module { } // Returns gen^-1 - pub fn galois_element_inv(&self, generator: i64) -> i64 { - if generator == 0 { + pub fn galois_element_inv(&self, gal_el: i64) -> i64 { + if gal_el == 0 { panic!("cannot invert 0") } ((mod_exp_u64( - generator.abs() as u64, + gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize, ) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() + * gal_el.signum() } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index f57e99f..85c2e1f 100644 --- 
a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -574,6 +574,7 @@ impl VecZnxOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); + assert!(k & 1 != 0, "invalid galois element: must be odd but is {}", k); } unsafe { vec_znx::vec_znx_automorphism( diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 8741bf9..6913fab 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -1,6 +1,7 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, - ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, + ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, + VecZnxOps, ZnxZero, }; use sampling::source::Source; @@ -20,10 +21,10 @@ pub struct AutomorphismKey { } impl AutomorphismKey, FFT64> { - pub fn new(module: &Module, basek: usize, p: i64, k: usize, rows: usize, rank: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { AutomorphismKey { key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank), - p: p, + p: 0, } } } @@ -127,6 +128,20 @@ impl AutomorphismKey, FFT64> { GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size) } + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(rank + 1, out_size); + let idft: usize = module.vec_znx_idft_tmp_bytes(); + let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, out_size, rank, ksk_size); + tmp_dft + tmp_idft + idft + keyswitch + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, @@ -267,7 +282,7 @@ where // 
Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i); + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); }); // Wraps into ciphertext @@ -283,7 +298,7 @@ where // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) // and switches back to DFT domain (0..self.rank_out() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft, i); module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i); }); @@ -299,6 +314,8 @@ where self.set_row(module, row_i, col_j, &tmp_dft); }); }); + + self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); } pub fn keyswitch( diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs new file mode 100644 index 0000000..0e81578 --- /dev/null +++ b/core/src/test_fft64/automorphism_key.rs @@ -0,0 +1,117 @@ +use base2k::{FFT64, Module, ScalarZnxOps, ScalarZnxToRef, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use sampling::source::Source; + +use crate::{ + automorphism::AutomorphismKey, + elem::{GetRow, Infos}, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + test_fft64::gglwe::noise_gglwe_product, +}; + +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-1, 5, 12, 12, 60, 3.2, rank); + }); +} + +fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_out: AutomorphismKey, FFT64> = 
AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key_out.size()) + | AutomorphismKey::automorphism_scratch_space( + &module, + auto_key_out.size(), + auto_key_in.size(), + auto_key_apply.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key_in.encrypt_sk( + &module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + &module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut sk_auto: SecretKey> = SecretKey::new(&module, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i); + }); + + let mut sk_auto_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_auto_dft.dft(&module, &sk_auto); + + (0..auto_key_out.rank_in()).for_each(|col_i| { + 
(0..auto_key_out.rows()).for_each(|row_i| { + auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index ff4bcfe..f0ccdb5 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -405,6 +405,19 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); + scratch = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size()) + | GLWESwitchingKey::external_product_scratch_space( + &module, + ct_gglwe_out.size(), + ct_gglwe_in.size(), + ct_rgsw.size(), + rank_out, + ) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), + ); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 53c06fe..687d49a 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -437,7 +437,7 @@ fn test_automorphism( let module: Module = Module::::new(1 << log_n); let rows: usize = (k_ct_in + basek - 1) / basek; - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, p, k_autokey, rows, 
rank); + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank); let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); @@ -523,19 +523,11 @@ fn automorphism_inplace() { }); } -fn test_automorphism_inplace( - log_n: usize, - basek: usize, - p: i64, - k_autokey: usize, - k_ct: usize, - rank: usize, - sigma: f64, -) { +fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k_ct + basek - 1) / basek; - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, p, k_autokey, rows, rank); + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank); let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index ffaf1dc..9af0cfc 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,3 +1,4 @@ +mod automorphism_key; mod gglwe; mod ggsw; mod glwe; From c5fe07188fbfb0cd2e317e745213095cdbc879db Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 14:22:05 +0200 Subject: [PATCH 79/87] fixed tests of automorphism over glwe --- core/src/automorphism.rs | 18 ++++++++++++++++++ core/src/glwe_ciphertext.rs | 6 +++--- core/src/test_fft64/automorphism_key.rs | 2 +- core/src/test_fft64/glwe.rs | 18 ++++++++---------- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 6913fab..64b3cf7 100644 --- 
a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -142,6 +142,10 @@ impl AutomorphismKey, FFT64> { tmp_dft + tmp_idft + idft + keyswitch } + pub fn automorphism_inplace_scratch_space(module: &Module, out_size: usize, ksk_size: usize, rank: usize) -> usize { + AutomorphismKey::automorphism_scratch_space(module, out_size, out_size, ksk_size, rank) + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, @@ -318,6 +322,20 @@ where self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); } + pub fn automorphism_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut AutomorphismKey = self as *mut AutomorphismKey; + self.automorphism(&module, &*self_ptr, rhs, scratch); + } + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 0fd6242..ca94db1 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -294,9 +294,9 @@ where MatZnxDft: MatZnxDftToRef, { self.keyswitch(module, lhs, &rhs.key, scratch); - //(0..self.rank() + 1).for_each(|i| { - // module.vec_znx_automorphism_inplace(rhs.p(), self, i); - //}) + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), self, i); + }) } pub fn automorphism_inplace( diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index 0e81578..9705a3f 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -1,4 +1,4 @@ -use base2k::{FFT64, Module, ScalarZnxOps, ScalarZnxToRef, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use base2k::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 687d49a..54f389c 100644 --- a/core/src/test_fft64/glwe.rs +++ 
b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxViewMut, ZnxZero, + ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -420,7 +420,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, fn automorphism() { (1..4).for_each(|rank| { println!("test automorphism rank: {}", rank); - test_automorphism(12, 12, 1, 60, 45, 60, rank, 3.2); + test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); }); } @@ -447,7 +447,6 @@ fn test_automorphism( let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - // Random input plaintext pt_want .data .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); @@ -486,14 +485,15 @@ fn test_automorphism( ); ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow()); let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + println!("{}", noise_have); + let noise_want: f64 = noise_gglwe_product( module.n() as f64, basek, @@ -519,7 +519,7 @@ fn test_automorphism( fn automorphism_inplace() { (1..4).for_each(|rank| { println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(12, 12, 1, 60, 60, rank, 3.2); + test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); }); } @@ -575,12 +575,10 @@ fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usiz ); ct.automorphism_inplace(&module, &autokey, scratch.borrow()); - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, 
&pt_want, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow()); let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_gglwe_product( From 8f2eac4928f1026c74077d896262b20e20e66c53 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 18:06:14 +0200 Subject: [PATCH 80/87] Added tensor key & associated test --- base2k/src/scalar_znx.rs | 6 +- base2k/src/scalar_znx_dft.rs | 72 +++++++++++++- base2k/src/vec_znx.rs | 10 ++ base2k/src/vec_znx_dft.rs | 10 +- core/src/ggsw_ciphertext.rs | 82 ++++++++++++++++ core/src/lib.rs | 1 + core/src/tensor_key.rs | 125 ++++++++++++++++++++++++ core/src/test_fft64/automorphism_key.rs | 99 +++++++++++++++++++ core/src/test_fft64/ggsw.rs | 121 ++++++++++++++++++++++- core/src/test_fft64/glwe.rs | 34 +++---- core/src/test_fft64/mod.rs | 1 + core/src/test_fft64/tensor_key.rs | 77 +++++++++++++++ 12 files changed, 610 insertions(+), 28 deletions(-) create mode 100644 core/src/tensor_key.rs create mode 100644 core/src/test_fft64/tensor_key.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index fa812a8..4c981c1 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -9,9 +9,9 @@ use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; pub struct ScalarZnx { - data: D, - n: usize, - cols: usize, + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, } impl ZnxInfos for ScalarZnx { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 3626625..248b87d 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64}; 
pub struct ScalarZnxDft { data: D, @@ -92,6 +92,16 @@ impl ScalarZnxDft { _phantom: PhantomData, } } + + pub fn as_vec_znx_dft(self) -> VecZnxDft{ + VecZnxDft{ + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } } pub type ScalarZnxDftOwned = ScalarZnxDft, B>; @@ -158,3 +168,63 @@ impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { } } } + +impl VecZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} \ No newline at end of file diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b945b2c..5d9f1ca 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,5 +1,6 @@ use crate::DataView; use crate::DataViewMut; +use crate::ScalarZnx; use crate::ZnxSliceSize; use crate::ZnxZero; use crate::alloc_aligned; @@ -128,6 +129,15 @@ impl VecZnx { size, } } + + pub fn to_scalar_znx(self) -> ScalarZnx{ + debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols); + ScalarZnx{ + data: 
self.data, + n: self.n, + cols: self.cols, + } + } } /// Copies the coefficients of `a` on the receiver. diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b4bc973..7b4ec29 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -8,11 +8,11 @@ use crate::{ use std::fmt; pub struct VecZnxDft { - data: D, - n: usize, - cols: usize, - size: usize, - _phantom: PhantomData, + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) _phantom: PhantomData, } impl VecZnxDft { diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 577bd6e..4a8c0a8 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -6,6 +6,7 @@ use base2k::{ use sampling::source::Source; use crate::{ + automorphism::AutomorphismKey, elem::{GetRow, Infos, SetRow}, glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -78,6 +79,20 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + auto_key_size: usize, + rank: usize, + ) -> usize { + let size: usize = in_size.min(out_size); + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size); + let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size); + let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size); + tmp_dft + tmp_idft + vmp + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, @@ -182,6 +197,73 @@ where }); } + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + 
self.rank(), + rhs.rank(), + "ggsw_in rank: {} != auto_key rank: {}", + self.rank(), + rhs.rank() + ); + } + + let size: usize = self.size().min(lhs.size()); + let cols: usize = self.rank() + 1; + + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size); + + let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size); + + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_data, + basek: self.basek(), + k: self.k(), + }; + + (0..cols).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2); + (0..cols).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + }); + self.set_row(module, row_j, col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank() + 1).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + } + pub fn external_product( &mut self, module: &Module, diff --git a/core/src/lib.rs b/core/src/lib.rs index f04ca06..74ed7ef 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,6 +7,7 @@ pub mod glwe_ciphertext_fourier; pub mod glwe_plaintext; pub mod keys; pub mod keyswitch_key; +pub mod tensor_key; #[cfg(test)] mod test_fft64; mod utils; diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs new file mode 100644 index 0000000..5625b51 --- /dev/null +++ b/core/src/tensor_key.rs @@ -0,0 +1,125 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef, +}; +use sampling::source::Source; + +use crate::{ + 
elem::Infos, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct TensorKey { + pub(crate) keys: Vec>, +} + +impl TensorKey, FFT64> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut keys: Vec, FFT64>> = Vec::new(); + let pairs: usize = ((rank + 1) * rank) >> 1; + (0..pairs).for_each(|_| { + keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank)); + }); + Self { keys: keys } + } +} + +impl Infos for TensorKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.keys[0].inner() + } + + fn basek(&self) -> usize { + self.keys[0].basek() + } + + fn k(&self) -> usize { + self.keys[0].k() + } +} + +impl TensorKey { + pub fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub fn rank_in(&self) -> usize { + self.keys[0].rank_in() + } + + pub fn rank_out(&self) -> usize { + self.keys[0].rank_out() + } +} + +impl TensorKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size) + } +} + +impl TensorKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: VecZnxDftToRef + ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let rank: usize = self.rank(); + + (0..rank).for_each(|i| { + (i..rank).for_each(|j| { + let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx<&mut [u8]> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + 
let sk_ij: SecretKey<&mut [u8]> = SecretKey { + data: sk_ij, + dist: sk_dft.dist, + }; + + self.at_mut(i, j).encrypt_sk( + module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1, + ); + }); + }) + } + + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } + + // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index 9705a3f..6ac6b40 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -18,6 +18,14 @@ fn automorphism() { }); } +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank); + }); +} + fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); let rows = (k_ksk + basek - 1) / basek; @@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, }); }); } + +fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + + let mut source_xs: Source = 
Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size()) + | AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key.encrypt_sk( + &module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + &module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut sk_auto: SecretKey> = SecretKey::new(&module, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i); + }); + + let mut sk_auto_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_auto_dft.dft(&module, &sk_auto); + + (0..auto_key.rank_in()).for_each(|col_i| { + (0..auto_key.rows()).for_each(|row_i| { + auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); + + let 
noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index cf34dda..4325426 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -5,6 +5,7 @@ use base2k::{ use sampling::source::Source; use crate::{ + automorphism::AutomorphismKey, elem::{GetRow, Infos}, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: }); } +// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { +// let module: Module = Module::::new(1 << log_n); +// let rows: usize = (k_ggsw + basek - 1) / basek; +// +// let mut ct_ggsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); +// let mut ct_ggsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); +// let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); +// +// let mut pt_ggsw_in: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_ggsw_out: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size()) +// | 
GGSWCiphertext::automorphism_scratch_space( +// &module, +// ct_ggsw_out.size(), +// ct_ggsw_in.size(), +// auto_key.size(), +// rank, +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_ggsw_in.encrypt_sk( +// &module, +// &pt_ggsw_in, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// auto_key.encrypt_sk( +// &module, +// p, +// &sk, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow()); +// +// let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); +// +// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { +// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); +// +// if col_j > 0 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); +// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// 
+// let noise_have: f64 = pt.data.std(0, basek).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_ggsw_product( +// module.n() as f64, +// basek, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// rank as f64, +// k_ggsw, +// k_ggsw, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); @@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) | GGSWCiphertext::external_product_scratch_space( &module, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 54f389c..e0323fa 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxView, ZnxViewMut, ZnxZero, + ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -75,6 +75,22 @@ fn external_product_inplace() { }); } +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test 
automorphism_inplace rank: {}", rank); + test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); + }); +} + fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); @@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); } -#[test] -fn automorphism() { - (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); - }); -} - fn test_automorphism( log_n: usize, basek: usize, @@ -515,14 +523,6 @@ fn test_automorphism( ); } -#[test] -fn automorphism_inplace() { - (1..4).for_each(|rank| { - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); - }); -} - fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k_ct + basek - 1) / basek; diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 9af0cfc..fb2129e 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -3,3 +3,4 @@ mod gglwe; mod ggsw; mod glwe; mod glwe_fourier; +mod tensor_key; diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs new file mode 100644 index 0000000..920341b --- /dev/null +++ b/core/src/test_fft64/tensor_key.rs @@ -0,0 +1,77 @@ +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + 
tensor_key::TensorKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(12, 16, 54, 3.2, rank); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k + basek - 1) / basek; + + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space( + &module, + rank, + tensor_key.size(), + )); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + let mut sk_ij_dft: base2k::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); + 
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); + }) + }) +} From 06b3cccbffaceeb0e6e4bc23dd4a440955621db1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 11:43:18 +0200 Subject: [PATCH 81/87] Added GGSW key-switching along with algo description --- base2k/src/mat_znx_dft_ops.rs | 60 +++++++++++ core/src/automorphism.rs | 12 +-- core/src/elem.rs | 12 +-- core/src/gglwe_ciphertext.rs | 12 +-- core/src/ggsw_ciphertext.rs | 194 ++++++++++++++++++++++++++++++---- core/src/keyswitch_key.rs | 10 +- core/src/tensor_key.rs | 23 ++-- 7 files changed, 272 insertions(+), 51 deletions(-) diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 7b4ac36..24be2e2 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -99,6 +99,13 @@ pub trait MatZnxDftOps { R: VecZnxDftToMut, A: VecZnxDftToRef, B: MatZnxDftToRef; + + // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. 
+ fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef; } impl MatZnxDftAlloc for Module { @@ -301,6 +308,59 @@ impl MatZnxDftOps for Module { ) } } + + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft_add( + self.ptr, + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } } #[cfg(test)] mod tests { diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 64b3cf7..8b4fe3a 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -1,7 +1,7 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, - ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxOps, ZnxZero, + ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, 
VecZnxOps, + ZnxZero, }; use sampling::source::Source; @@ -85,9 +85,9 @@ impl GetRow for AutomorphismKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -97,9 +97,9 @@ impl SetRow for AutomorphismKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/elem.rs b/core/src/elem.rs index 4562137..66cb1d0 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ -use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; +use base2k::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; +use crate::utils::derive_size; pub trait Infos { type Inner: ZnxInfos; @@ -47,13 +47,13 @@ pub trait Infos { } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut; + R: VecZnxDftToMut; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef; + R: VecZnxDftToRef; } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 863fd54..f8983c8 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -1,7 +1,7 @@ use base2k::{ Backend, 
FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, - VecZnxOps, ZnxInfos, ZnxZero, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, + ZnxZero, }; use sampling::source::Source; @@ -190,9 +190,9 @@ impl GetRow for GGLWECiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -202,9 +202,9 @@ impl SetRow for GGLWECiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 4a8c0a8..1b00bd3 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -12,6 +12,8 @@ use crate::{ glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + 
tensor_key::TensorKey, utils::derive_size, }; @@ -86,10 +88,9 @@ impl GGSWCiphertext, FFT64> { auto_key_size: usize, rank: usize, ) -> usize { - let size: usize = in_size.min(out_size); - let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size); - let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size); - let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size); + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); + let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); tmp_dft + tmp_idft + vmp } @@ -197,6 +198,167 @@ where }); } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + lhs.rank(), + ksk.rank(), + "ggsw_in rank: {} != ksk rank: {}", + lhs.rank(), + ksk.rank() + ); + assert_eq!( + lhs.rank(), + tsk.rank(), + "ggsw_in rank: {} != tsk rank: {}", + lhs.rank(), + tsk.rank() + ); + } + + let cols: usize = self.rank() + 1; + + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many rows. 
+ // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M, a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M, b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) + // + // # Output + // + // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) + (0..self.rows()).for_each(|row_j| { + + let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_out_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + { + let (tmp_dft_in_data, scratch2) = scratch1.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); + + let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + // 1) Applies key-switching to GGSW[i][0]: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) + lhs.get_row(module, row_j, 0, &mut tmp_dft_in); + // (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + tmp_dft_out.keyswitch(module, &tmp_dft_in, ksk, scratch2); + self.set_row(module, row_j, 0, &tmp_dft_out); + } + + // 2) Isolates IDFT(-(a0s0' + a1s1' + a2s2') + M[i]) + let (mut tmp_c0_data, scratch2) = scratch1.tmp_vec_znx_big(module, 1, self.size()); + module.vec_znx_idft_tmp_a(&mut tmp_c0_data, 0, &mut tmp_dft_out, 0); + + // 3) Expands the i-th row of the other columns using the tensor key + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) = KS_{s0's0', s0's1', s0's2'}(a0) + (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + 
a2s2') + M[i], 0) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) + (1..cols).for_each(|col_i| { + + let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); + let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_i_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + // 5) Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) + (1..cols).for_each(|col_j| { + + // Extracts a[i] and multipies with Enc(s'[i]s'[j]) + let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); + tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); + + if col_j == 1 { + module.vmp_apply( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) + scratch4, + ); + } else { + module.vmp_apply_add( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) + scratch4, + ); + } + }); + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) + // + + // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) + // = + // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) + // = + // (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) + { + let (mut tmp_idft, scratch3) = 
scratch3.tmp_vec_znx_big(module, 1, tsk.size()); + let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0); + module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); + module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); + }); + } + + // Stores (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) + self.set_row(module, row_j, col_i, &tmp_dft_i); + }) + }) + } + pub fn automorphism( &mut self, module: &Module, @@ -224,11 +386,10 @@ where rhs.rank() ); } - - let size: usize = self.size().min(lhs.size()); +; let cols: usize = self.rank() + 1; - let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size); + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_dft_data, @@ -236,7 +397,7 @@ where k: lhs.k(), }; - let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size); + let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, self.size()); let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: tmp_idft_data, @@ -366,14 +527,9 @@ impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut GLWECiphertextFourier, - ) where - VecZnxDft: VecZnxDftToMut, + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -383,9 +539,9 @@ impl SetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, 
col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index e01df09..cade469 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, + ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, }; use sampling::source::Source; @@ -74,9 +74,9 @@ impl GetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -86,9 +86,9 @@ impl SetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index 5625b51..158274d 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -105,15 +105,6 @@ where }) } - // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } - // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j 
{ @@ -123,3 +114,17 @@ where &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } + +impl TensorKey +where + MatZnxDft: MatZnxDftToRef, +{ + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} From 7d84477e6411390bc312306a5ea66f4de4e930d2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 13:51:13 +0200 Subject: [PATCH 82/87] working GGSW key-switch + added test (missing noise formula) --- core/src/ggsw_ciphertext.rs | 47 +++++++++---- core/src/test_fft64/ggsw.rs | 135 ++++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 14 deletions(-) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 1b00bd3..2546b7d 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,7 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -81,6 +81,27 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + tsk_size: usize, + rank: usize, + ) -> usize { + let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, 
out_size); + let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size); + let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size); + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size) + + module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size); + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm)))) + } + pub fn automorphism_scratch_space( module: &Module, out_size: usize, @@ -90,7 +111,8 @@ impl GGSWCiphertext, FFT64> { ) -> usize { let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); - let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); + let vmp: usize = + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); tmp_dft + tmp_idft + vmp } @@ -198,7 +220,6 @@ where }); } - pub fn keyswitch( &mut self, module: &Module, @@ -239,7 +260,7 @@ where let cols: usize = self.rank() + 1; // Example for rank 3: - // + // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is // actually composed of that many rows. 
// @@ -250,14 +271,13 @@ where // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) // - // # Output + // # Output // // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) (0..self.rows()).for_each(|row_j| { - let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { @@ -291,7 +311,6 @@ where // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) (1..cols).for_each(|col_i| { - let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_dft_i_data, @@ -311,7 +330,6 @@ where // = // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) (1..cols).for_each(|col_j| { - // Extracts a[i] and multipies with Enc(s'[i]s'[j]) let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); @@ -336,7 +354,7 @@ where // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) - // + + // + // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) // = // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) @@ -347,7 +365,9 @@ where let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); (0..cols).for_each(|i| { 
module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0); + if i == col_i { + module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0); + } module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); }); @@ -385,8 +405,7 @@ where self.rank(), rhs.rank() ); - } -; + }; let cols: usize = self.rank() + 1; let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 4325426..07cb650 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -12,6 +12,8 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, + tensor_key::TensorKey, + test_fft64::gglwe::noise_gglwe_product, }; #[test] @@ -105,6 +107,139 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: }); } +#[test] +fn keyswitch() { + (1..4).for_each(|rank| { + println!("test keyswitch rank: {}", rank); + test_keyswitch(12, 15, 60, rank, 3.2); + }); +} + +fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut 
source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + tsk.size(), + rank, + ), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + (0..ct_out.rows()).for_each(|row_i| { + 
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + println!("{} {}", noise_have, noise_want); + + // assert!( + // (noise_have - noise_want).abs() <= 0.1, + // "{} {}", + // noise_have, + // noise_want + // ); + + pt_want.data.zero(); + }); + }); +} + // fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { // let module: Module = Module::::new(1 << log_n); // let rows: usize = (k_ggsw + basek - 1) / basek; From a803127424bdb81c0317b25ebe3053594a8822d1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 14:36:26 +0200 Subject: [PATCH 83/87] Added noise equation for keyswitch over GGSW & updated associated test --- core/src/test_fft64/automorphism_key.rs | 6 +- core/src/test_fft64/gglwe.rs | 35 +++++++++-- core/src/test_fft64/ggsw.rs | 79 ++++++++++++++++++++----- core/src/test_fft64/glwe.rs | 10 ++-- core/src/test_fft64/glwe_fourier.rs | 6 +- 5 files changed, 107 insertions(+), 29 deletions(-) diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index 6ac6b40..ea63550 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -7,7 +7,7 @@ use crate::{ 
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - test_fft64::gglwe::noise_gglwe_product, + test_fft64::gglwe::log2_std_noise_gglwe_product, }; #[test] @@ -101,7 +101,7 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -192,7 +192,7 @@ fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index f0ccdb5..d497dbf 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -212,7 +212,7 @@ fn test_key_switch( module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -314,7 +314,7 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -571,7 +571,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 }); } -pub(crate) fn noise_gglwe_product( +pub(crate) fn var_noise_gglwe_product( n: f64, basek: usize, var_xs: f64, @@ -597,7 +597,34 @@ pub(crate) fn 
noise_gglwe_product( let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); noise += var_msg * var_a_err * a_scale * a_scale * n; noise *= rank_in; + noise /= b_scale * b_scale; + noise +} + +pub(crate) fn log2_std_noise_gglwe_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_msg, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_in, + a_logq, + b_logq, + ); noise = noise.sqrt(); - noise /= b_scale; noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 07cb650..40237c9 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -13,9 +13,11 @@ use crate::{ keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, tensor_key::TensorKey, - test_fft64::gglwe::noise_gglwe_product, + test_fft64::gglwe::log2_std_noise_gglwe_product, }; +use super::gglwe::var_noise_gglwe_product; + #[test] fn encrypt_sk() { (1..4).for_each(|rank| { @@ -146,14 +148,16 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) ), ); + let var_xs: f64 = 0.5; + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); - sk_in.fill_ternary_prob(0.5, &mut source_xs); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_in_dft.dft(&module, &sk_in); let mut sk_out: SecretKey> = SecretKey::new(&module, rank); - sk_out.fill_ternary_prob(0.5, &mut source_xs); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_out_dft.dft(&module, &sk_out); @@ -213,11 +217,11 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: 
usize, sigma: f64) module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = noise_ggsw_keyswitch( module.n() as f64, basek, - 0.5, - 0.5, + col_j, + var_xs, 0f64, sigma * sigma, 0f64, @@ -226,20 +230,67 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) k, ); - println!("{} {}", noise_have, noise_want); - - // assert!( - // (noise_have - noise_want).abs() <= 0.1, - // "{} {}", - // noise_have, - // noise_want - // ); + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); pt_want.data.zero(); }); }); } +pub(crate) fn noise_ggsw_keyswitch( + n: f64, + basek: usize, + col: usize, + var_xs: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let var_si_x_sj: f64 = n * var_xs * var_xs; + + // Initial KS for col = 0 + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_xs, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + a_logq, + b_logq, + ); + + // Other GGSW reconstruction for col > 0 + if col > 0 { + noise += var_noise_gglwe_product( + n, + basek, + var_xs, + var_si_x_sj, + var_a_err + 1f64 / 12.0, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + a_logq, + b_logq, + ); + noise += n * noise * var_xs * 0.5; + } + + noise = noise.sqrt(); + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} + // fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { // let module: Module = Module::::new(1 << log_n); // let rows: usize = (k_ggsw + basek - 1) / basek; diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index e0323fa..0f7fcc1 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -14,7 +14,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, 
SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, + test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product}, }; #[test] @@ -326,7 +326,7 @@ fn test_keyswitch( module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -411,7 +411,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -502,7 +502,7 @@ fn test_automorphism( println!("{}", noise_have); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -581,7 +581,7 @@ fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usiz module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow()); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 3887558..d8bd11c 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -6,7 +6,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, + test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, 
VecZnxToMut, ZnxViewMut}; use sampling::source::Source; @@ -132,7 +132,7 @@ fn test_keyswitch( module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -220,7 +220,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_gglwe_product( + let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek, 0.5, From 640ff9ea614e11405966d738728f496b51823f0c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 17:42:43 +0200 Subject: [PATCH 84/87] Refactor of GGSW key-switch to enable easier implementation of GGSW automorphism --- base2k/src/module.rs | 5 +- base2k/src/scalar_znx_dft.rs | 11 +- base2k/src/vec_znx.rs | 14 +- base2k/src/vec_znx_big.rs | 10 +- base2k/src/vec_znx_big_ops.rs | 2 - base2k/src/vec_znx_dft_ops.rs | 69 +++++++ base2k/src/vec_znx_ops.rs | 6 +- core/src/ggsw_ciphertext.rs | 332 +++++++++++++++++++--------------- core/src/test_fft64/ggsw.rs | 243 +++++++++++++------------ 9 files changed, 404 insertions(+), 288 deletions(-) diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 8ee6e4b..f6d0e0e 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -78,10 +78,7 @@ impl Module { if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64( - gal_el.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) + ((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64) * gal_el.signum() } } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 248b87d..fa4ab10 100644 --- a/base2k/src/scalar_znx_dft.rs +++ 
b/base2k/src/scalar_znx_dft.rs @@ -2,7 +2,10 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64}; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, + alloc_aligned, +}; pub struct ScalarZnxDft { data: D, @@ -93,8 +96,8 @@ impl ScalarZnxDft { } } - pub fn as_vec_znx_dft(self) -> VecZnxDft{ - VecZnxDft{ + pub fn as_vec_znx_dft(self) -> VecZnxDft { + VecZnxDft { data: self.data, n: self.n, cols: self.cols, @@ -227,4 +230,4 @@ impl VecZnxDftToRef for ScalarZnxDft<&[u8], B> { _phantom: PhantomData, } } -} \ No newline at end of file +} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 5d9f1ca..d4b0b9c 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -130,9 +130,13 @@ impl VecZnx { } } - pub fn to_scalar_znx(self) -> ScalarZnx{ - debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols); - ScalarZnx{ + pub fn to_scalar_znx(self) -> ScalarZnx { + debug_assert_eq!( + self.size, 1, + "cannot convert VecZnx to ScalarZnx if cols: {} != 1", + self.cols + ); + ScalarZnx { data: self.data, n: self.n, cols: self.cols, @@ -198,9 +202,9 @@ where VecZnx: VecZnxToMut + ZnxInfos, { /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
- pub fn extract_column(&mut self, self_col: usize, a: &VecZnx, a_col: usize) + pub fn extract_column(&mut self, self_col: usize, a: &R, a_col: usize) where - VecZnx: VecZnxToRef + ZnxInfos, + R: VecZnxToRef + ZnxInfos, { #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index eba90e9..2bf4dcc 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use std::fmt; use std::marker::PhantomData; @@ -97,11 +97,11 @@ impl VecZnxBig { impl VecZnxBig where VecZnxBig: VecZnxBigToMut + ZnxInfos, -{ - // Consumes the VecZnxBig to return a VecZnx. +{ + // Consumes the VecZnxBig to return a VecZnx. // Useful when no normalization is needed. 
- pub fn to_vec_znx_small(self) -> VecZnx{ - VecZnx{ + pub fn to_vec_znx_small(self) -> VecZnx { + VecZnx { data: self.data, n: self.n, cols: self.cols, diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index f6dad7a..8208c97 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -147,7 +147,6 @@ pub trait VecZnxBigOps { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: VecZnxBigToMut; - } pub trait VecZnxBigScratch { @@ -170,7 +169,6 @@ impl VecZnxBigAlloc for Module { } impl VecZnxBigOps for Module { - fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxBigToMut, diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 27e6f59..3e5965b 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -42,6 +42,17 @@ pub trait VecZnxDftOps { /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, @@ -84,6 +95,64 @@ impl VecZnxDftAlloc for Module { } impl VecZnxDftOps for Module { + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + 
vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 85c2e1f..b97e6b7 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -574,7 +574,11 @@ impl VecZnxOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); - assert!(k & 1 != 0, "invalid galois element: must be odd but is {}", k); + assert!( + k & 1 != 0, + "invalid galois element: must be odd but is {}", + k + ); } unsafe { vec_znx::vec_znx_automorphism( diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 2546b7d..e8f913c 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,7 +1,8 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, 
ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, + VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, + VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -81,6 +82,26 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub(crate) fn expand_row_scratch_space(module: &Module, self_size: usize, tsk_size: usize, rank: usize) -> usize { + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); + let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) + } + + pub(crate) fn keyswitch_internal_col0_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + + module.bytes_of_vec_znx_dft(rank + 1, in_size) + } + pub fn keyswitch_scratch_space( module: &Module, out_size: usize, @@ -89,17 +110,12 @@ impl GGSWCiphertext, FFT64> { tsk_size: usize, rank: usize, ) -> usize { - let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); - let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size); - let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size); - let tmp_dft_i: usize = 
module.bytes_of_vec_znx_dft(rank + 1, tsk_size); - let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size) - + module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); - let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size); - let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm)))) + let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tsk_size, rank); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + res_znx + ci_dft + (ks | expand_rows | res_dft) } pub fn automorphism_scratch_space( @@ -186,19 +202,19 @@ where k, }; - (0..self.rows()).for_each(|row_j| { + (0..self.rows()).for_each(|row_i| { vec_znx_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); - (0..cols).for_each(|col_i| { + (0..cols).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct vec_znx_ct.encrypt_sk_private( module, - Some((&vec_znx_pt, col_i)), + Some((&vec_znx_pt, col_j)), sk_dft, source_xa, source_xe, @@ -214,167 +230,151 @@ where module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); }); - module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); + self.set_row(module, row_i, col_j, &vec_znx_dft_ct); } }); }); } - pub fn keyswitch( + pub(crate) fn expand_row( &mut self, module: &Module, - lhs: &GGSWCiphertext, - ksk: 
&GLWESwitchingKey, - tsk: &TensorKey, + col_j: usize, + res: &mut R, + ci_dft: &VecZnxDft, + tsk: &TensorKey, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, + R: VecZnxToMut, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - lhs.rank(), - ksk.rank(), - "ggsw_in rank: {} != ksk rank: {}", - lhs.rank(), - ksk.rank() - ); - assert_eq!( - lhs.rank(), - tsk.rank(), - "ggsw_in rank: {} != tsk rank: {}", - lhs.rank(), - tsk.rank() - ); - } - let cols: usize = self.rank() + 1; // Example for rank 3: // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many rows. + // actually composed of that many rows and we focus on a specific row here + // implicitely given ci_dft. // // # Input // - // col 0: (-(a0s0 + a1s1 + a2s2) + M, a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M, b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) // // # Output // - // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) - (0..self.rows()).for_each(|row_j| { - let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - let mut tmp_dft_out: 
GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_out_data, - basek: lhs.basek(), - k: lhs.k(), - }; + let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); + { + let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size()); - { - let (tmp_dft_in_data, scratch2) = scratch1.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); - - let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_in_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - // 1) Applies key-switching to GGSW[i][0]: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) - lhs.get_row(module, row_j, 0, &mut tmp_dft_in); - // (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - tmp_dft_out.keyswitch(module, &tmp_dft_in, ksk, scratch2); - self.set_row(module, row_j, 0, &tmp_dft_out); - } - - // 2) Isolates IDFT(-(a0s0' + a1s1' + a2s2') + M[i]) - let (mut tmp_c0_data, scratch2) = scratch1.tmp_vec_znx_big(module, 1, self.size()); - module.vec_znx_idft_tmp_a(&mut tmp_c0_data, 0, &mut tmp_dft_out, 0); - - // 3) Expands the i-th row of the other columns using the tensor key - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) = KS_{s0's0', s0's1', s0's2'}(a0) + (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) + // Performs a key-switch for each combination of s[i]*s[j], i.e. 
for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); - let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_i_data, - basek: lhs.basek(), - k: lhs.k(), - }; + // Extracts a[i] and multipies with Enc(s[i]s[j]) + tmp_dft_col_data.extract_column(0, ci_dft, col_i); - // 5) Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) - (1..cols).for_each(|col_j| { - // Extracts a[i] and multipies with Enc(s'[i]s'[j]) - let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); - tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); + if col_i == 1 { + module.vmp_apply( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } else { + module.vmp_apply_add( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } + }); + } - if col_j == 1 { - 
module.vmp_apply( - &mut tmp_dft_i, - &tmp_dft_col_data, - tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) - scratch4, - ); - } else { - module.vmp_apply_add( - &mut tmp_dft_i, - &tmp_dft_col_data, - tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) - scratch4, - ); - } + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); + let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size()); + (0..cols).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2); + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let cols: usize = self.rank() + 1; + + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + basek: self.basek(), + k: self.k(), + }; + + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. 
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2); + + // Isolates DFT(a[i]) + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut res_dft, i, &res, i); }); - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) - // + - // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) - // = - // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) - // = - // (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) - { - let (mut tmp_idft, scratch3) = scratch3.tmp_vec_znx_big(module, 1, tsk.size()); - let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - if i == col_i { - module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0); - } - module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); - module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); - }); - } - - // Stores (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) - self.set_row(module, row_j, col_i, &tmp_dft_i); + self.set_row(module, row_i, col_j, &res_dft); }) }) } @@ -542,6 +542,38 @@ where } } +impl GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + pub(crate) fn keyswitch_internal_col0( + &self, 
+ module: &Module, + row_i: usize, + res: &mut GLWECiphertext, + ksk: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), ksk.rank()); + assert_eq!(res.rank(), ksk.rank()); + } + + let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_in_data, + basek: self.basek(), + k: self.k(), + }; + self.get_row(module, row_i, 0, &mut tmp_dft_in); + res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2); + } +} + impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 40237c9..2ee8708 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -13,7 +13,6 @@ use crate::{ keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, tensor_key::TensorKey, - test_fft64::gglwe::log2_std_noise_gglwe_product, }; use super::gglwe::var_noise_gglwe_product; @@ -230,6 +229,8 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) k, ); + println!("{} {}", noise_have, noise_want); + assert!( (noise_have - noise_want).abs() <= 0.1, "{} {}", @@ -291,122 +292,130 @@ pub(crate) fn noise_ggsw_keyswitch( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } -// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { -// let module: Module = Module::::new(1 << log_n); -// let rows: usize = (k_ggsw + basek - 1) / basek; -// -// let mut ct_ggsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); -// let mut ct_ggsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); -// let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, 
rank); -// -// let mut pt_ggsw_in: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_ggsw_out: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size()) -// | GGSWCiphertext::automorphism_scratch_space( -// &module, -// ct_ggsw_out.size(), -// ct_ggsw_in.size(), -// auto_key.size(), -// rank, -// ), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_ggsw_in.encrypt_sk( -// &module, -// &pt_ggsw_in, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// auto_key.encrypt_sk( -// &module, -// p, -// &sk, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow()); -// -// let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); -// -// 
(0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { -// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); -// -// if col_j > 0 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); -// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, basek).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_ggsw_product( -// module.n() as f64, -// basek, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// rank as f64, -// k_ggsw, -// k_ggsw, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } +fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = 
GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + tsk.size(), + rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, 
ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + (0..ct_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); From fb35dfa0f7c0d60e0cbcf91fcd40a1f7e66805ba Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 21:38:38 +0200 Subject: [PATCH 85/87] Added automorphism + test on GGSW --- core/src/ggsw_ciphertext.rs | 112 ++++++++++++++++++++++-------------- core/src/test_fft64/ggsw.rs | 60 ++++++++++--------- 2 files changed, 101 insertions(+), 71 deletions(-) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index e8f913c..28c5961 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -82,11 +82,17 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub(crate) fn expand_row_scratch_space(module: &Module, self_size: 
usize, tsk_size: usize, rank: usize) -> usize { - let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + pub(crate) fn expand_row_scratch_space( + module: &Module, + self_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size); let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); - let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let vmp: usize = + tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size); let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) } @@ -107,13 +113,13 @@ impl GGSWCiphertext, FFT64> { out_size: usize, in_size: usize, ksk_size: usize, - tsk_size: usize, + tensor_key_size: usize, rank: usize, ) -> usize { let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tsk_size, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank); let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); res_znx + ci_dft + (ks | expand_rows | res_dft) } @@ -123,13 +129,17 @@ impl GGSWCiphertext, FFT64> { out_size: usize, in_size: usize, auto_key_size: usize, + tensor_key_size: usize, rank: usize, ) -> usize { - let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); - let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); - let vmp: 
usize = - GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); - tmp_dft + tmp_idft + vmp + GGSWCiphertext::keyswitch_scratch_space( + module, + out_size, + in_size, + auto_key_size, + tensor_key_size, + rank, + ) } pub fn external_product_scratch_space( @@ -379,15 +389,17 @@ where }) } - pub fn automorphism( + pub fn automorphism( &mut self, module: &Module, lhs: &GGSWCiphertext, - rhs: &AutomorphismKey, + auto_key: &AutomorphismKey, + tensor_key: &TensorKey, scratch: &mut Scratch, ) where MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { #[cfg(debug_assertions)] { @@ -400,48 +412,62 @@ where ); assert_eq!( self.rank(), - rhs.rank(), + auto_key.rank(), "ggsw_in rank: {} != auto_key rank: {}", self.rank(), - rhs.rank() + auto_key.rank() + ); + assert_eq!( + self.rank(), + tensor_key.rank(), + "ggsw_in rank: {} != tensor_key rank: {}", + self.rank(), + tensor_key.rank() ); }; + let cols: usize = self.rank() + 1; - let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize - - let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, self.size()); - - let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: tmp_idft_data, + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, basek: self.basek(), k: self.k(), }; - (0..cols).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_dft); - tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // 
Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2); + + // Isolates DFT(AUTO(a[i])) + (0..cols).for_each(|col_i| { + // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) + module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i); + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); (0..cols).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + module.vec_znx_dft(&mut res_dft, i, &res, i); }); - self.set_row(module, row_j, col_i, &tmp_dft); - }); - }); - tmp_dft.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank() + 1).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_dft); - }); - }); + self.set_row(module, row_i, col_j, &res_dft); + }) + }) } pub fn external_product( diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 2ee8708..cf5e3c7 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,6 +1,6 @@ use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, + FFT64, Module, ScalarZnx, ScalarZnxAlloc, 
ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, }; use sampling::source::Source; @@ -292,14 +292,22 @@ pub(crate) fn noise_ggsw_keyswitch( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-5, 12, 15, 60, rank, 3.2); + }); +} + fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k + basek - 1) / basek; let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -311,44 +319,38 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | GGSWCiphertext::keyswitch_scratch_space( + | AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) 
+ | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | GGSWCiphertext::automorphism_scratch_space( &module, ct_out.size(), ct_in.size(), - ksk.size(), - tsk.size(), + auto_key.size(), + tensor_key.size(), rank, ), ); let var_xs: f64 = 0.5; - let mut sk_in: SecretKey> = SecretKey::new(&module, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_in_dft.dft(&module, &sk_in); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); - let mut sk_out: SecretKey> = SecretKey::new(&module, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - - let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_out_dft.dft(&module, &sk_out); - - ksk.encrypt_sk( + auto_key.encrypt_sk( &module, - &sk_in, - &sk_out_dft, + p, + &sk, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - tsk.encrypt_sk( + tensor_key.encrypt_sk( &module, - &sk_out_dft, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -360,14 +362,16 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, ct_in.encrypt_sk( &module, &pt_scalar, - &sk_in_dft, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); @@ -380,14 +384,14 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, // mul with sk[col_j-1] if col_j > 0 { 
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); From ba27dcf3e6b068199212e1456ed1052e1e7d5a4e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 21 May 2025 09:33:11 +0200 Subject: [PATCH 86/87] fixed a typo & small optimization --- base2k/src/vec_znx_dft_ops.rs | 3 --- core/src/automorphism.rs | 4 ++-- core/src/glwe_ciphertext.rs | 15 ++++----------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 3e5965b..e4d6c33 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -148,9 +148,6 @@ impl VecZnxDftOps for Module { ); }); } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) } fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 8b4fe3a..8dca7ec 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -352,12 +352,12 @@ where pub fn keyswitch_inplace( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWESwitchingKey, scratch: &mut base2k::Scratch, ) where MatZnxDft: MatZnxDftToRef, { - self.key.keyswitch_inplace(module, &rhs.key, scratch); + self.key.keyswitch_inplace(module, &rhs, scratch); } pub fn external_product( diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index ca94db1..e319d21 100644 --- 
a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -141,11 +141,9 @@ impl GLWECiphertext> { let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + module.bytes_of_vec_znx_dft(in_rank, in_size); - let a0_big: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); - let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | a0_big | norm) + res_dft + (vmp | norm) } pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { @@ -362,15 +360,10 @@ where module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); } - // Switches result of VMP outside of DFT - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0); - { - // Switches lhs 0-th outside of DFT domain and adds on - let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); - module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); - module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); - } + // Switches result of VMP outside of DFT + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); (0..cols_out).for_each(|i| { module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); From fa067228da0fe4cdfbfd432608c7834f1f0509a0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 21 May 2025 11:31:28 +0200 Subject: [PATCH 87/87] Added remaining missing test --- core/src/ggsw_ciphertext.rs | 59 ++++++++ core/src/test_fft64/ggsw.rs | 280 +++++++++++++++++++++++++++++++++--- 2 files changed, 323 insertions(+), 16 deletions(-) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 28c5961..8955b7a 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -124,6 +124,16 @@ impl GGSWCiphertext, FFT64> { res_znx + ci_dft + (ks | 
expand_rows | res_dft) } + pub fn keyswitch_inplace_scratch_space( + module: &Module, + out_size: usize, + ksk_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + GGSWCiphertext::keyswitch_scratch_space(module, out_size, out_size, ksk_size, tensor_key_size, rank) + } + pub fn automorphism_scratch_space( module: &Module, out_size: usize, @@ -142,6 +152,23 @@ impl GGSWCiphertext, FFT64> { ) } + pub fn automorphism_inplace_scratch_space( + module: &Module, + out_size: usize, + auto_key_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + GGSWCiphertext::automorphism_scratch_space( + module, + out_size, + out_size, + auto_key_size, + tensor_key_size, + rank, + ) + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, @@ -389,6 +416,22 @@ where }) } + pub fn keyswitch_inplace( + &mut self, + module: &Module, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); + } + } + pub fn automorphism( &mut self, module: &Module, @@ -470,6 +513,22 @@ where }) } + pub fn automorphism_inplace( + &mut self, + module: &Module, + auto_key: &AutomorphismKey, + tensor_key: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); + } + } + pub fn external_product( &mut self, module: &Module, diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index cf5e3c7..f02bd87 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -25,6 +25,38 @@ fn encrypt_sk() { }); } +#[test] +fn keyswitch() { + (1..4).for_each(|rank| { + println!("test keyswitch rank: {}", rank); + 
test_keyswitch(12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn keyswitch_inplace() { + (1..4).for_each(|rank| { + println!("test keyswitch_inplace rank: {}", rank); + test_keyswitch_inplace(12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-5, 12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-5, 12, 15, 60, rank, 3.2); + }); +} + #[test] fn external_product() { (1..4).for_each(|rank| { @@ -108,14 +140,6 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: }); } -#[test] -fn keyswitch() { - (1..4).for_each(|rank| { - println!("test keyswitch rank: {}", rank); - test_keyswitch(12, 15, 60, rank, 3.2); - }); -} - fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k + basek - 1) / basek; @@ -243,6 +267,125 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) }); } +fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = 
Source::new([0u8; 32]);
+
+    let mut scratch: ScratchOwned = ScratchOwned::new(
+        GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
+            | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size())
+            | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
+            | TensorKey::encrypt_sk_scratch_space(&module, rank, tsk.size())
+            | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), tsk.size(), rank),
+    );
+
+    let var_xs: f64 = 0.5;
+
+    let mut sk_in: SecretKey> = SecretKey::new(&module, rank);
+    sk_in.fill_ternary_prob(var_xs, &mut source_xs);
+
+    let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank);
+    sk_in_dft.dft(&module, &sk_in);
+
+    let mut sk_out: SecretKey> = SecretKey::new(&module, rank);
+    sk_out.fill_ternary_prob(var_xs, &mut source_xs);
+
+    let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank);
+    sk_out_dft.dft(&module, &sk_out);
+
+    ksk.encrypt_sk(
+        &module,
+        &sk_in,
+        &sk_out_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+    tsk.encrypt_sk(
+        &module,
+        &sk_out_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
+
+    ct.encrypt_sk(
+        &module,
+        &pt_scalar,
+        &sk_in_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow());
+
+    let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
+    let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size());
+    let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size());
+
+    (0..ct.rank() + 1).for_each(|col_j| {
+        (0..ct.rows()).for_each(|row_i| {
+            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+
+            // mul with sk[col_j-1]
+            if col_j > 0 {
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
+                
module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + println!("{} {}", noise_have, noise_want); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + pub(crate) fn noise_ggsw_keyswitch( n: f64, basek: usize, @@ -292,14 +435,6 @@ pub(crate) fn noise_ggsw_keyswitch( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } -#[test] -fn automorphism() { - (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(-5, 12, 15, 60, rank, 3.2); - }); -} - fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k + basek - 1) / basek; @@ -421,6 +556,119 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, }); } +fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); + let mut pt_have: GLWEPlaintext> = 
GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()) + | AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | GGSWCiphertext::automorphism_inplace_scratch_space(&module, ct.size(), auto_key.size(), tensor_key.size(), rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + auto_key.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &auto_key, &tensor_key, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + 
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n);