From 06b3cccbffaceeb0e6e4bc23dd4a440955621db1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 11:43:18 +0200 Subject: [PATCH] Added GGSW key-switching along with algo description --- base2k/src/mat_znx_dft_ops.rs | 60 +++++++++++ core/src/automorphism.rs | 12 +-- core/src/elem.rs | 12 +-- core/src/gglwe_ciphertext.rs | 12 +-- core/src/ggsw_ciphertext.rs | 194 ++++++++++++++++++++++++++++++---- core/src/keyswitch_key.rs | 10 +- core/src/tensor_key.rs | 23 ++-- 7 files changed, 272 insertions(+), 51 deletions(-) diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 7b4ac36..24be2e2 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -99,6 +99,13 @@ pub trait MatZnxDftOps { R: VecZnxDftToMut, A: VecZnxDftToRef, B: MatZnxDftToRef; + + // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef; } impl MatZnxDftAlloc for Module { @@ -301,6 +308,59 @@ impl MatZnxDftOps for Module { ) } } + + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft_add( + self.ptr, + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } } #[cfg(test)] mod tests { diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 64b3cf7..8b4fe3a 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -1,7 +1,7 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, - ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxOps, ZnxZero, + ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, + ZnxZero, }; use sampling::source::Source; @@ -85,9 +85,9 @@ impl GetRow for AutomorphismKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -97,9 +97,9 @@ impl SetRow for AutomorphismKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/elem.rs b/core/src/elem.rs index 4562137..66cb1d0 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ -use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; +use base2k::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; +use crate::utils::derive_size; pub trait Infos { type Inner: ZnxInfos; @@ -47,13 +47,13 @@ pub trait Infos { } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut; + R: VecZnxDftToMut; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef; + R: VecZnxDftToRef; } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 863fd54..f8983c8 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -1,7 +1,7 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, - VecZnxOps, ZnxInfos, ZnxZero, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, + ZnxZero, }; use sampling::source::Source; @@ -190,9 +190,9 @@ impl GetRow for GGLWECiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -202,9 +202,9 @@ impl SetRow for GGLWECiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 4a8c0a8..1b00bd3 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -12,6 +12,8 @@ use crate::{ glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + tensor_key::TensorKey, utils::derive_size, }; @@ -86,10 +88,9 @@ impl GGSWCiphertext, FFT64> { auto_key_size: usize, rank: usize, ) -> usize { - let size: usize = in_size.min(out_size); - let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size); - let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size); - let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size); + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); + let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); tmp_dft + tmp_idft + vmp } @@ -197,6 +198,167 @@ where }); } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + lhs.rank(), + ksk.rank(), + "ggsw_in rank: {} != ksk rank: {}", + lhs.rank(), + ksk.rank() + ); + assert_eq!( + lhs.rank(), + tsk.rank(), + "ggsw_in rank: {} != tsk rank: {}", + lhs.rank(), + tsk.rank() + ); + } + + let cols: usize = self.rank() + 1; + + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many rows. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M, a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M, b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) + // + // # Output + // + // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) + (0..self.rows()).for_each(|row_j| { + + let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_out_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + { + let (tmp_dft_in_data, scratch2) = scratch1.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); + + let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + // 1) Applies key-switching to GGSW[i][0]: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) + lhs.get_row(module, row_j, 0, &mut tmp_dft_in); + // (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + tmp_dft_out.keyswitch(module, &tmp_dft_in, ksk, scratch2); + self.set_row(module, row_j, 0, &tmp_dft_out); + } + + // 2) Isolates IDFT(-(a0s0' + a1s1' + a2s2') + M[i]) + let (mut tmp_c0_data, scratch2) = scratch1.tmp_vec_znx_big(module, 1, self.size()); + module.vec_znx_idft_tmp_a(&mut tmp_c0_data, 0, &mut tmp_dft_out, 0); + + // 3) Expands the i-th row of the other columns using the tensor key + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) = KS_{s0's0', s0's1', s0's2'}(a0) + (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) + (1..cols).for_each(|col_i| { + + let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); + let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_i_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + // 5) Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) + (1..cols).for_each(|col_j| { + + // Extracts a[i] and multipies with Enc(s'[i]s'[j]) + let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); + tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); + + if col_j == 1 { + module.vmp_apply( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) + scratch4, + ); + } else { + module.vmp_apply_add( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s'[i]s'[j]) + scratch4, + ); + } + }); + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) + // + + // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) + // = + // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) + // = + // (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) + { + let (mut tmp_idft, scratch3) = scratch3.tmp_vec_znx_big(module, 1, tsk.size()); + let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0); + module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); + module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); + }); + } + + // Stores (-(x0s0' + x1s1' + x2s2'), x0 + M[i], x1, x2) + self.set_row(module, row_j, col_i, &tmp_dft_i); + }) + }) + } + pub fn automorphism( &mut self, module: &Module, @@ -224,11 +386,10 @@ where rhs.rank() ); } - - let size: usize = self.size().min(lhs.size()); +; let cols: usize = self.rank() + 1; - let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size); + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_dft_data, @@ -236,7 +397,7 @@ where k: lhs.k(), }; - let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size); + let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, self.size()); let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: tmp_idft_data, @@ -366,14 +527,9 @@ impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut GLWECiphertextFourier, - ) where - VecZnxDft: VecZnxDftToMut, + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -383,9 +539,9 @@ impl SetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index e01df09..cade469 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, + ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, }; use sampling::source::Source; @@ -74,9 +74,9 @@ impl GetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where - VecZnxDft: VecZnxDftToMut, + R: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } @@ -86,9 +86,9 @@ impl SetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) where - VecZnxDft: VecZnxDftToRef, + R: VecZnxDftToRef, { module.vmp_prepare_row(self, row_i, col_j, a); } diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index 5625b51..158274d 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -105,15 +105,6 @@ where }) } - // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } - // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j { @@ -123,3 +114,17 @@ where &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } + +impl TensorKey +where + MatZnxDft: MatZnxDftToRef, +{ + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +}