From b93e011347b5198012372038cb34c7c4960db432 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 17 Jun 2025 09:46:22 +0200 Subject: [PATCH] fixed automorphism on gglwe for k_out < k_in --- core/src/automorphism.rs | 52 ++++++++++++++++--------- core/src/test_fft64/automorphism_key.rs | 8 ++-- core/src/test_fft64/mod.rs | 2 +- 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index e9e4a0b..27ea44a 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -249,24 +249,33 @@ impl + AsRef<[u8]>> AutomorphismKey { self.rank_out(), rhs.rank_out() ); + assert!( + self.k() <= lhs.k(), + "output k={} cannot be greater than input k={}", + self.k(), + lhs.k() + ) } let cols_out: usize = rhs.rank_out() + 1; - let (mut tmp_dft, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - // Extracts relevant row - lhs.get_row(module, row_j, col_i, &mut tmp_dft); + let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); - // Get a VecZnxBig from scratch space - let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + { + let (mut tmp_dft, scratch2) = scratct1.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - // Switches input outside of DFT - (0..cols_out).for_each(|i| { - module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); - }); + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + } // Consumes to small vec znx let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); @@ -284,20 +293,25 @@ impl + AsRef<[u8]>> AutomorphismKey { }; // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); - // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) - // and switches back to DFT domain - (0..self.rank_out() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); - module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); - }); + { + let (mut tmp_dft, _) = scratct1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - // Sets back the relevant row - self.set_row(module, row_j, col_i, &tmp_dft); + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); + module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + } }); }); + let (mut tmp_dft, _) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); tmp_dft.data.zero(); (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index fa1ca37..f23b619 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -10,7 +10,7 @@ fn automorphism() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 60; - let k_out: usize = 60; + let k_out: usize = 40; let digits: usize = k_in.div_ceil(basek); let sigma: f64 = 3.2; (1..4).for_each(|rank| { @@ -141,12 +141,12 @@ fn test_automorphism( sigma * sigma, 0f64, rank as f64, - k_in, + k_out, k_apply, ); assert!( - (noise_have - noise_want).abs() <= 0.5, + noise_have < noise_want + 0.5, "{} {}", noise_have, noise_want @@ -260,7 +260,7 @@ fn test_automorphism_inplace( ); assert!( - (noise_have - noise_want).abs() <= 0.5, + noise_have < noise_want + 0.5, "{} {}", noise_have, noise_want diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 6fcecb1..73a58e9 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -62,7 +62,7 @@ pub(crate) fn log2_std_noise_gglwe_product( b_logq, ); noise = noise.sqrt(); - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] + noise.log2().min(-1.0).max(-(a_logq as f64)) // max noise is [-2^{-1}, 2^{-1}] } pub(crate) fn noise_ggsw_product(