From 7d84477e6411390bc312306a5ea66f4de4e930d2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 20 May 2025 13:51:13 +0200 Subject: [PATCH] working GGSW key-switch + added test (missing noise formula) --- core/src/ggsw_ciphertext.rs | 47 +++++++++---- core/src/test_fft64/ggsw.rs | 135 ++++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 14 deletions(-) diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 1b00bd3..2546b7d 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,7 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -81,6 +81,27 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + tsk_size: usize, + rank: usize, + ) -> usize { + let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size); + let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size); + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size) + + module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size); + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm)))) + } + pub fn automorphism_scratch_space( module: &Module, out_size: usize, @@ -90,7 +111,8 @@ impl GGSWCiphertext, FFT64> { ) -> usize { let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); - let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); + let vmp: usize = + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); tmp_dft + tmp_idft + vmp } @@ -198,7 +220,6 @@ where }); } - pub fn keyswitch( &mut self, module: &Module, @@ -239,7 +260,7 @@ where let cols: usize = self.rank() + 1; // Example for rank 3: - // + // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is // actually composed of that many rows. // @@ -250,14 +271,13 @@ where // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M) // - // # Output + // # Output // // col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 ) // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 ) // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 ) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M) (0..self.rows()).for_each(|row_j| { - let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { @@ -291,7 +311,6 @@ where // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i]) (1..cols).for_each(|col_i| { - let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size()); let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_dft_i_data, @@ -311,7 +330,6 @@ where // = // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2) (1..cols).for_each(|col_j| { - // Extracts a[i] and multipies with Enc(s'[i]s'[j]) let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size()); tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j); @@ -336,7 +354,7 @@ where // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // // (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2) - // + + // + // (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0) // = // (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2) @@ -347,7 +365,9 @@ where let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size()); (0..cols).for_each(|i| { module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0); + if i == col_i { + module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0); + } module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5); module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0); }); @@ -385,8 +405,7 @@ where self.rank(), rhs.rank() ); - } -; + }; let cols: usize = self.rank() + 1; let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 4325426..07cb650 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -12,6 +12,8 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, + tensor_key::TensorKey, + test_fft64::gglwe::noise_gglwe_product, }; #[test] @@ -105,6 +107,139 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: }); } +#[test] +fn keyswitch() { + (1..4).for_each(|rank| { + println!("test keyswitch rank: {}", rank); + test_keyswitch(12, 15, 60, rank, 3.2); + }); +} + +fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + tsk.size(), + rank, + ), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + (0..ct_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + println!("{} {}", noise_have, noise_want); + + // assert!( + // (noise_have - noise_want).abs() <= 0.1, + // "{} {}", + // noise_have, + // noise_want + // ); + + pt_want.data.zero(); + }); + }); +} + // fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { // let module: Module = Module::::new(1 << log_n); // let rows: usize = (k_ggsw + basek - 1) / basek;