From 640ff9ea614e11405966d738728f496b51823f0c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 17:42:43 +0200 Subject: [PATCH] Refactor of GGSW key-switch to enable easier implementation of GGSW automorphism --- base2k/src/module.rs | 5 +- base2k/src/scalar_znx_dft.rs | 11 +- base2k/src/vec_znx.rs | 14 +- base2k/src/vec_znx_big.rs | 10 +- base2k/src/vec_znx_big_ops.rs | 2 - base2k/src/vec_znx_dft_ops.rs | 69 +++++++ base2k/src/vec_znx_ops.rs | 6 +- core/src/ggsw_ciphertext.rs | 332 +++++++++++++++++++--------------- core/src/test_fft64/ggsw.rs | 243 +++++++++++++------------ 9 files changed, 404 insertions(+), 288 deletions(-) diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 8ee6e4b..f6d0e0e 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -78,10 +78,7 @@ impl Module { if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64( - gal_el.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) + ((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64) * gal_el.signum() } } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 248b87d..fa4ab10 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,7 +2,10 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64}; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, + alloc_aligned, +}; pub struct ScalarZnxDft { data: D, @@ -93,8 +96,8 @@ impl ScalarZnxDft { } } - pub fn as_vec_znx_dft(self) -> VecZnxDft{ - VecZnxDft{ + pub fn as_vec_znx_dft(self) -> VecZnxDft { + VecZnxDft { data: self.data, n: self.n, cols: self.cols, @@ -227,4 +230,4 @@ impl VecZnxDftToRef for ScalarZnxDft<&[u8], B> { _phantom: PhantomData, } } -} \ No newline at end of file +} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 5d9f1ca..d4b0b9c 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -130,9 +130,13 @@ impl VecZnx { } } - pub fn to_scalar_znx(self) -> ScalarZnx{ - debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols); - ScalarZnx{ + pub fn to_scalar_znx(self) -> ScalarZnx { + debug_assert_eq!( + self.size, 1, + "cannot convert VecZnx to ScalarZnx if cols: {} != 1", + self.cols + ); + ScalarZnx { data: self.data, n: self.n, cols: self.cols, @@ -198,9 +202,9 @@ where VecZnx: VecZnxToMut + ZnxInfos, { /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. - pub fn extract_column(&mut self, self_col: usize, a: &VecZnx, a_col: usize) + pub fn extract_column(&mut self, self_col: usize, a: &R, a_col: usize) where - VecZnx: VecZnxToRef + ZnxInfos, + R: VecZnxToRef + ZnxInfos, { #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index eba90e9..2bf4dcc 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use std::fmt; use std::marker::PhantomData; @@ -97,11 +97,11 @@ impl VecZnxBig { impl VecZnxBig where VecZnxBig: VecZnxBigToMut + ZnxInfos, -{ - // Consumes the VecZnxBig to return a VecZnx. +{ + // Consumes the VecZnxBig to return a VecZnx. // Useful when no normalization is needed. - pub fn to_vec_znx_small(self) -> VecZnx{ - VecZnx{ + pub fn to_vec_znx_small(self) -> VecZnx { + VecZnx { data: self.data, n: self.n, cols: self.cols, diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index f6dad7a..8208c97 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -147,7 +147,6 @@ pub trait VecZnxBigOps { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: VecZnxBigToMut; - } pub trait VecZnxBigScratch { @@ -170,7 +169,6 @@ impl VecZnxBigAlloc for Module { } impl VecZnxBigOps for Module { - fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxBigToMut, diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 27e6f59..3e5965b 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -42,6 +42,17 @@ pub trait VecZnxDftOps { /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, @@ -84,6 +95,64 @@ impl VecZnxDftAlloc for Module { } impl VecZnxDftOps for Module { + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 85c2e1f..b97e6b7 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -574,7 +574,11 @@ impl VecZnxOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); - assert!(k & 1 != 0, "invalid galois element: must be odd but is {}", k); + assert!( + k & 1 != 0, + "invalid galois element: must be odd but is {}", + k + ); } unsafe { vec_znx::vec_znx_automorphism( diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 2546b7d..e8f913c 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,7 +1,8 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, + VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, + VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -81,6 +82,26 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub(crate) fn expand_row_scratch_space(module: &Module, self_size: usize, tsk_size: usize, rank: usize) -> usize { + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); + let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) + } + + pub(crate) fn keyswitch_internal_col0_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + + module.bytes_of_vec_znx_dft(rank + 1, in_size) + } + pub fn keyswitch_scratch_space( module: &Module, out_size: usize, @@ -89,17 +110,12 @@ impl GGSWCiphertext, FFT64> { tsk_size: usize, rank: usize, ) -> usize { - let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); - let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size); - let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size); - let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); - let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size) - + module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); - let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size); - let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm)))) + let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tsk_size, rank); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + res_znx + ci_dft + (ks | expand_rows | res_dft) } pub fn automorphism_scratch_space( @@ -186,19 +202,19 @@ where k, }; - (0..self.rows()).for_each(|row_j| { + (0..self.rows()).for_each(|row_i| { vec_znx_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); - (0..cols).for_each(|col_i| { + (0..cols).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct vec_znx_ct.encrypt_sk_private( module, - Some((&vec_znx_pt, col_i)), + Some((&vec_znx_pt, col_j)), sk_dft, source_xa, source_xe, @@ -214,167 +230,151 @@ where module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); }); - module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); + self.set_row(module, row_i, col_j, &vec_znx_dft_ct); } }); }); } - pub fn keyswitch( + pub(crate) fn expand_row( &mut self, module: &Module, - lhs: &GGSWCiphertext, - ksk: &GLWESwitchingKey, - tsk: &TensorKey, + col_j: usize, + res: &mut R, + ci_dft: &VecZnxDft, + tsk: &TensorKey, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, + R: VecZnxToMut, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - lhs.rank(), - ksk.rank(), - "ggsw_in rank: {} != ksk rank: {}", - lhs.rank(), - ksk.rank() - ); - assert_eq!( - lhs.rank(), - tsk.rank(), - "ggsw_in rank: {} != tsk rank: {}", - lhs.rank(), - tsk.rank() - ); - } - let cols: usize = self.rank() + 1; // Example for rank 3: // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many rows. + // actually composed of that many rows and we focus on a specific row here + // implicitely given ci_dft. // // # Input // - // col 0: (-(a0s0 + a1s1 + a2s2) + M, a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M, b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) // // # Output // - // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) - (0..self.rows()).for_each(|row_j| { - let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_out_data, - basek: lhs.basek(), - k: lhs.k(), - }; + let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); + { + let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size()); - { - let (tmp_dft_in_data, scratch2) = scratch1.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); - - let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_in_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - // 1) Applies key-switching to GGSW[i][0]: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) - lhs.get_row(module, row_j, 0, &mut tmp_dft_in); - // (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - tmp_dft_out.keyswitch(module, &tmp_dft_in, ksk, scratch2); - self.set_row(module, row_j, 0, &tmp_dft_out); - } - - // 2) Isolates IDFT(-(a0s0' + a1s1' + a2s2') + M[i]) - let (mut tmp_c0_data, scratch2) = scratch1.tmp_vec_znx_big(module, 1, self.size()); - module.vec_znx_idft_tmp_a(&mut tmp_c0_data, 0, &mut tmp_dft_out, 0); - - // 3) Expands the i-th row of the other columns using the tensor key - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) = KS_{s0's0', s0's1', s0's2'}(a0) + (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); - let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_i_data, - basek: lhs.basek(), - k: lhs.k(), - }; + // Extracts a[i] and multipies with Enc(s[i]s[j]) + tmp_dft_col_data.extract_column(0, ci_dft, col_i); - // 5) Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) - (1..cols).for_each(|col_j| { - // Extracts a[i] and multipies with Enc(s'[i]s'[j]) - let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); - tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); + if col_i == 1 { + module.vmp_apply( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } else { + module.vmp_apply_add( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } + }); + } - if col_j == 1 { - module.vmp_apply( - &mut tmp_dft_i, - &tmp_dft_col_data, - tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) - scratch4, - ); - } else { - module.vmp_apply_add( - &mut tmp_dft_i, - &tmp_dft_col_data, - tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) - scratch4, - ); - } + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); + let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size()); + (0..cols).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2); + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let cols: usize = self.rank() + 1; + + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + basek: self.basek(), + k: self.k(), + }; + + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2); + + // Isolates DFT(a[i]) + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut res_dft, i, &res, i); }); - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) - // + - // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) - // = - // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) - // = - // (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) - { - let (mut tmp_idft, scratch3) = scratch3.tmp_vec_znx_big(module, 1, tsk.size()); - let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - if i == col_i { - module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0); - } - module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); - module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); - }); - } - - // Stores (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) - self.set_row(module, row_j, col_i, &tmp_dft_i); + self.set_row(module, row_i, col_j, &res_dft); }) }) } @@ -542,6 +542,38 @@ where } } +impl GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + pub(crate) fn keyswitch_internal_col0( + &self, + module: &Module, + row_i: usize, + res: &mut GLWECiphertext, + ksk: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), ksk.rank()); + assert_eq!(res.rank(), ksk.rank()); + } + + let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_in_data, + basek: self.basek(), + k: self.k(), + }; + self.get_row(module, row_i, 0, &mut tmp_dft_in); + res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2); + } +} + impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 40237c9..2ee8708 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -13,7 +13,6 @@ use crate::{ keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, tensor_key::TensorKey, - test_fft64::gglwe::log2_std_noise_gglwe_product, }; use super::gglwe::var_noise_gglwe_product; @@ -230,6 +229,8 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) k, ); + println!("{} {}", noise_have, noise_want); + assert!( (noise_have - noise_want).abs() <= 0.1, "{} {}", @@ -291,122 +292,130 @@ pub(crate) fn noise_ggsw_keyswitch( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } -// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { -// let module: Module = Module::::new(1 << log_n); -// let rows: usize = (k_ggsw + basek - 1) / basek; -// -// let mut ct_ggsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); -// let mut ct_ggsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); -// let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); -// -// let mut pt_ggsw_in: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_ggsw_out: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size()) -// | GGSWCiphertext::automorphism_scratch_space( -// &module, -// ct_ggsw_out.size(), -// ct_ggsw_in.size(), -// auto_key.size(), -// rank, -// ), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_ggsw_in.encrypt_sk( -// &module, -// &pt_ggsw_in, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// auto_key.encrypt_sk( -// &module, -// p, -// &sk, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow()); -// -// let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); -// -// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { -// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); -// -// if col_j > 0 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); -// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, basek).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_ggsw_product( -// module.n() as f64, -// basek, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// rank as f64, -// k_ggsw, -// k_ggsw, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } +fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + tsk.size(), + rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + (0..ct_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n);