From 7434f289fe8a8c9156b1529e9eb6e52888d207f7 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 16 May 2025 14:15:41 +0200 Subject: [PATCH] Added automorphism for glwe --- base2k/src/scalar_znx.rs | 64 +++++++++ core/src/automorphism.rs | 253 +++++++++++++++++++++++++++++++++++ core/src/gglwe_ciphertext.rs | 9 +- core/src/glwe_ciphertext.rs | 71 ++++++++-- core/src/keyswitch_key.rs | 62 ++++----- core/src/lib.rs | 1 + core/src/test_fft64/glwe.rs | 109 +++++++++++++++ 7 files changed, 521 insertions(+), 48 deletions(-) create mode 100644 core/src/automorphism.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 108ba3f..8da145f 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,3 +1,4 @@ +use crate::ffi::vec_znx; use crate::znx_base::ZnxInfos; use crate::{ Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, @@ -122,6 +123,69 @@ impl ScalarZnxAlloc for Module { } } +pub trait ScalarZnxOps { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef; + + /// Applies the automorphism X^i -> X^ik on the selected column of `a`. + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut; +} + +impl ScalarZnxOps for Module { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: ScalarZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut, + { + let mut a: ScalarZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + impl ScalarZnx { pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols } diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs new file mode 100644 index 0000000..ed6a954 --- /dev/null +++ b/core/src/automorphism.rs @@ -0,0 +1,253 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, + ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct AutomorphismKey { + pub(crate) key: GLWESwitchingKey, + pub(crate) p: i64, +} + +impl AutomorphismKey, FFT64> { + pub fn new(module: &Module, basek: usize, p: i64, k: usize, rows: usize, rank: usize) -> Self { + AutomorphismKey { + key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank), + p: p, + } + } +} + +impl Infos for AutomorphismKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl AutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.key.rank_out() + } +} + +impl MatZnxDftToMut for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.key.to_mut() + } +} + +impl MatZnxDftToRef for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.key.to_ref() + } +} + +impl GetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl AutomorphismKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank) + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + p: i64, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank_out(), self.rank_in()); + assert_eq!(sk.rank(), self.rank()); + } + + let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); + + let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier { + data: sk_out_dft_data, + dist: sk.dist, + }; + + { + (0..self.rank()).for_each(|i| { + let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1); + module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i); + module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0); + }); + } + + self.key.encrypt_sk( + module, + &sk, + &sk_out_dft, + source_xa, + source_xe, + sigma, + scratch_1, + ); + + self.p = p; + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch(module, &lhs.key, rhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch_inplace(module, &rhs.key, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product(module, &lhs.key, rhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product_inplace(module, rhs, scratch); + } +} diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 7deb225..863fd54 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -21,10 +21,10 @@ pub struct GGLWECiphertext { } impl GGLWECiphertext, B> { - pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(base2k, k)), - basek: base2k, + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)), + basek: basek, k, } } @@ -161,6 +161,7 @@ where (0..cols_in).for_each(|col_i| { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + vec_znx_pt.data.zero(); // zeroes for next iteration module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); @@ -175,8 +176,6 @@ where scratch_3, ); - vec_znx_pt.data.zero(); // zeroes for next iteration - // Switch vec_znx_ct into DFT domain vec_znx_ct.dft(module, &mut vec_znx_ct_dft); diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index f0dbae1..422f2cc 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -8,6 +8,7 @@ use sampling::source::Source; use crate::{ SIX_SIGMA, + automorphism::AutomorphismKey, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, @@ -137,21 +138,40 @@ impl GLWECiphertext> { GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, in_size, out_rank, autokey_size) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, autokey_size) + } + pub fn external_product_scratch_space( module: &Module, out_size: usize, + out_rank: usize, in_size: usize, ggsw_size: usize, - rank: usize, ) -> usize { - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ggsw_size); + let vmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, in_size) + module.vmp_apply_tmp_bytes( out_size, in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols out + in_size, // rows + out_rank + 1, // cols in + out_rank + 1, // cols out ggsw_size, ); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); @@ -159,8 +179,13 @@ impl GLWECiphertext> { res_dft + (vmp | normalize) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + ggsw_size: usize, + ) -> usize { + GLWECiphertext::external_product_scratch_space(module, out_size, out_rank, out_size, ggsw_size) } } @@ -244,6 +269,36 @@ where self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); } + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch(module, lhs, &rhs.key, scratch); + //(0..self.rank() + 1).for_each(|i| { + // module.vec_znx_automorphism_inplace(rhs.p(), self, i); + //}) + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch_inplace(module, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), self, i); + }) + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 34595e3..e01df09 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -15,9 +15,9 @@ use crate::{ pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); impl GLWESwitchingKey, FFT64> { - pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { GLWESwitchingKey(GGLWECiphertext::new( - module, base2k, k, rows, rank_in, rank_out, + module, basek, k, rows, rank_in, rank_out, )) } } @@ -26,7 +26,7 @@ impl Infos for GLWESwitchingKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { - &self.0.inner() + self.0.inner() } fn basek(&self) -> usize { @@ -102,38 +102,7 @@ impl GLWESwitchingKey, FFT64> { pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) } -} -impl GLWESwitchingKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, -{ - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_in: &SecretKey, - sk_out_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - ScalarZnx: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - self.0.encrypt_sk( - module, - &sk_in.data, - sk_out_dft, - source_xa, - source_xe, - sigma, - scratch, - ); - } -} - -impl GLWESwitchingKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, out_size: usize, @@ -178,11 +147,34 @@ impl GLWESwitchingKey, FFT64> { tmp + ggsw } } - impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_in: &SecretKey, + sk_out_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out_dft, + source_xa, + source_xe, + sigma, + scratch, + ); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/lib.rs b/core/src/lib.rs index 60d57c2..f04ca06 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,3 +1,4 @@ +pub mod automorphism; pub mod elem; pub mod gglwe_ciphertext; pub mod ggsw_ciphertext; diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 37bfc4e..525de22 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -6,6 +6,7 @@ use itertools::izip; use sampling::source::Source; use crate::{ + automorphism::AutomorphismKey, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, @@ -415,6 +416,114 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); } +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(12, 12, 1, 60, 45, 60, rank, 3.2); + }); +} + +fn test_automorphism( + log_n: usize, + basek: usize, + p: i64, + k_autokey: usize, + k_ct_in: usize, + k_ct_out: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, p, k_autokey, rows, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + // pt_want + // .data + // .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want + .to_mut() + .at_mut(0, 1) + .iter_mut() + .enumerate() + .for_each(|(i, x)| { + *x = i as i64; + }); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertext::automorphism_scratch_space(&module, ct_out.size(), rank, ct_in.size(), autokey.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct_in, + k_autokey, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n);