From 8f2eac4928f1026c74077d896262b20e20e66c53 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 19 May 2025 18:06:14 +0200 Subject: [PATCH] Added tensor key & associated test --- base2k/src/scalar_znx.rs | 6 +- base2k/src/scalar_znx_dft.rs | 72 +++++++++++++- base2k/src/vec_znx.rs | 10 ++ base2k/src/vec_znx_dft.rs | 10 +- core/src/ggsw_ciphertext.rs | 82 ++++++++++++++++ core/src/lib.rs | 1 + core/src/tensor_key.rs | 125 ++++++++++++++++++++++++ core/src/test_fft64/automorphism_key.rs | 99 +++++++++++++++++++ core/src/test_fft64/ggsw.rs | 121 ++++++++++++++++++++++- core/src/test_fft64/glwe.rs | 34 +++---- core/src/test_fft64/mod.rs | 1 + core/src/test_fft64/tensor_key.rs | 77 +++++++++++++++ 12 files changed, 610 insertions(+), 28 deletions(-) create mode 100644 core/src/tensor_key.rs create mode 100644 core/src/test_fft64/tensor_key.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index fa812a8..4c981c1 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -9,9 +9,9 @@ use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; pub struct ScalarZnx { - data: D, - n: usize, - cols: usize, + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, } impl ZnxInfos for ScalarZnx { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 3626625..248b87d 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64}; pub struct ScalarZnxDft { data: D, @@ -92,6 +92,16 @@ impl ScalarZnxDft { _phantom: PhantomData, } } + + pub fn as_vec_znx_dft(self) -> VecZnxDft{ + VecZnxDft{ + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } } pub type ScalarZnxDftOwned = ScalarZnxDft, B>; @@ -158,3 +168,63 @@ impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { } } } + +impl VecZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} \ No newline at end of file diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b945b2c..5d9f1ca 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,5 +1,6 @@ use crate::DataView; use crate::DataViewMut; +use crate::ScalarZnx; use crate::ZnxSliceSize; use crate::ZnxZero; use crate::alloc_aligned; @@ -128,6 +129,15 @@ impl VecZnx { size, } } + + pub fn to_scalar_znx(self) -> ScalarZnx{ + debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols); + ScalarZnx{ + data: self.data, + n: self.n, + cols: self.cols, + } + } } /// Copies the coefficients of `a` on the receiver. diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b4bc973..7b4ec29 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -8,11 +8,11 @@ use crate::{ use std::fmt; pub struct VecZnxDft { - data: D, - n: usize, - cols: usize, - size: usize, - _phantom: PhantomData, + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) _phantom: PhantomData, } impl VecZnxDft { diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 577bd6e..4a8c0a8 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -6,6 +6,7 @@ use base2k::{ use sampling::source::Source; use crate::{ + automorphism::AutomorphismKey, elem::{GetRow, Infos, SetRow}, glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -78,6 +79,20 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + auto_key_size: usize, + rank: usize, + ) -> usize { + let size: usize = in_size.min(out_size); + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size); + let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size); + let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size); + tmp_dft + tmp_idft + vmp + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, @@ -182,6 +197,73 @@ where }); } + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_in rank: {} != auto_key rank: {}", + self.rank(), + rhs.rank() + ); + } + + let size: usize = self.size().min(lhs.size()); + let cols: usize = self.rank() + 1; + + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size); + + let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size); + + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_data, + basek: self.basek(), + k: self.k(), + }; + + (0..cols).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2); + (0..cols).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + }); + self.set_row(module, row_j, col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank() + 1).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + } + pub fn external_product( &mut self, module: &Module, diff --git a/core/src/lib.rs b/core/src/lib.rs index f04ca06..74ed7ef 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,6 +7,7 @@ pub mod glwe_ciphertext_fourier; pub mod glwe_plaintext; pub mod keys; pub mod keyswitch_key; +pub mod tensor_key; #[cfg(test)] mod test_fft64; mod utils; diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs new file mode 100644 index 0000000..5625b51 --- /dev/null +++ b/core/src/tensor_key.rs @@ -0,0 +1,125 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct TensorKey { + pub(crate) keys: Vec>, +} + +impl TensorKey, FFT64> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut keys: Vec, FFT64>> = Vec::new(); + let pairs: usize = ((rank + 1) * rank) >> 1; + (0..pairs).for_each(|_| { + keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank)); + }); + Self { keys: keys } + } +} + +impl Infos for TensorKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.keys[0].inner() + } + + fn basek(&self) -> usize { + self.keys[0].basek() + } + + fn k(&self) -> usize { + self.keys[0].k() + } +} + +impl TensorKey { + pub fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub fn rank_in(&self) -> usize { + self.keys[0].rank_in() + } + + pub fn rank_out(&self) -> usize { + self.keys[0].rank_out() + } +} + +impl TensorKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size) + } +} + +impl TensorKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: VecZnxDftToRef + ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let rank: usize = self.rank(); + + (0..rank).for_each(|i| { + (i..rank).for_each(|j| { + let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx<&mut [u8]> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + let sk_ij: SecretKey<&mut [u8]> = SecretKey { + data: sk_ij, + dist: sk_dft.dist, + }; + + self.at_mut(i, j).encrypt_sk( + module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1, + ); + }); + }) + } + + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } + + // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index 9705a3f..6ac6b40 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -18,6 +18,14 @@ fn automorphism() { }); } +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank); + }); +} + fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); let rows = (k_ksk + basek - 1) / basek; @@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, }); }); } + +fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size()) + | AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key.encrypt_sk( + &module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + &module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut sk_auto: SecretKey> = SecretKey::new(&module, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i); + }); + + let mut sk_auto_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_auto_dft.dft(&module, &sk_auto); + + (0..auto_key.rank_in()).for_each(|col_i| { + (0..auto_key.rows()).for_each(|row_i| { + auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index cf34dda..4325426 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -5,6 +5,7 @@ use base2k::{ use sampling::source::Source; use crate::{ + automorphism::AutomorphismKey, elem::{GetRow, Infos}, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: }); } +// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { +// let module: Module = Module::::new(1 << log_n); +// let rows: usize = (k_ggsw + basek - 1) / basek; +// +// let mut ct_ggsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); +// let mut ct_ggsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); +// let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); +// +// let mut pt_ggsw_in: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_ggsw_out: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size()) +// | GGSWCiphertext::automorphism_scratch_space( +// &module, +// ct_ggsw_out.size(), +// ct_ggsw_in.size(), +// auto_key.size(), +// rank, +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_ggsw_in.encrypt_sk( +// &module, +// &pt_ggsw_in, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// auto_key.encrypt_sk( +// &module, +// p, +// &sk, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow()); +// +// let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); +// +// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { +// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); +// +// if col_j > 0 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); +// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, basek).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_ggsw_product( +// module.n() as f64, +// basek, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// rank as f64, +// k_ggsw, +// k_ggsw, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); @@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) | GGSWCiphertext::external_product_scratch_space( &module, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 54f389c..e0323fa 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxView, ZnxViewMut, ZnxZero, + ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -75,6 +75,22 @@ fn external_product_inplace() { }); } +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); + }); +} + fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); @@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); } -#[test] -fn automorphism() { - (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); - }); -} - fn test_automorphism( log_n: usize, basek: usize, @@ -515,14 +523,6 @@ fn test_automorphism( ); } -#[test] -fn automorphism_inplace() { - (1..4).for_each(|rank| { - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); - }); -} - fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k_ct + basek - 1) / basek; diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 9af0cfc..fb2129e 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -3,3 +3,4 @@ mod gglwe; mod ggsw; mod glwe; mod glwe_fourier; +mod tensor_key; diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs new file mode 100644 index 0000000..920341b --- /dev/null +++ b/core/src/test_fft64/tensor_key.rs @@ -0,0 +1,77 @@ +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + tensor_key::TensorKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(12, 16, 54, 3.2, rank); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k + basek - 1) / basek; + + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space( + &module, + rank, + tensor_key.size(), + )); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + let mut sk_ij_dft: base2k::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); + }) + }) +}