From a3264b8851d466c97257d748129f85bee0b9193d Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Mon, 17 Nov 2025 16:48:52 +0100 Subject: [PATCH] Update cross-base2k keyswitch routine & tests, + add GLWE cross base2k conversion test --- poulpy-core/src/keyswitching/glwe.rs | 129 ++++++++++-------- .../src/tests/test_suite/conversion.rs | 32 +++-- .../tests/test_suite/encryption/glwe_ct.rs | 1 - .../src/tests/test_suite/keyswitch/glwe_ct.rs | 96 +++++++------ 4 files changed, 150 insertions(+), 108 deletions(-) diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index 72def40..be670e6 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -4,12 +4,12 @@ use poulpy_hal::{ VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, }; use crate::{ - ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos}, + GLWENormalize, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos}, }; impl GLWE> { @@ -47,7 +47,7 @@ impl GLWE { impl GLWEKeyswitch for Module where - Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize, + Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize + GLWENormalize, Scratch: ScratchTakeCore, { fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize @@ -57,9 +57,21 @@ where B: GGLWEInfos, { let cols: usize = res_infos.rank().as_usize() + 1; - self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) + let size: usize = self + .glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) .max(self.vec_znx_big_normalize_tmp_bytes()) - + self.bytes_of_vec_znx_dft(cols, key_infos.size()) + + self.bytes_of_vec_znx_dft(cols, key_infos.size()); + + if a_infos.base2k() != key_infos.base2k() { + size + GLWE::bytes_of_from_infos(&GLWELayout { + n: a_infos.n(), + base2k: key_infos.base2k(), + k: a_infos.k(), + rank: a_infos.rank(), + }) + } else { + size + } } fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -70,28 +82,28 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let b: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( a.rank(), - b.rank_in(), + key.rank_in(), "a.rank(): {} != b.rank_in(): {}", a.rank(), - b.rank_in() + key.rank_in() ); assert_eq!( res.rank(), - b.rank_out(), + key.rank_out(), "res.rank(): {} != b.rank_out(): {}", res.rank(), - b.rank_out() + key.rank_out() ); assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); - assert_eq!(b.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); - let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b); + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, key); assert!( scratch.available() >= scrach_needed, @@ -99,17 +111,31 @@ where scratch.available(), ); - let basek_out: usize = res.base2k().into(); - let base2k_out: usize = b.base2k().into(); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2) + } else { + self.glwe_keyswitch_internal(res_dft, a, key, scratch_1) + }; - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1); for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( - basek_out, + base2k_res, &mut res.data, i, - base2k_out, + base2k_key, &res_big, i, scratch_1, @@ -151,17 +177,31 @@ where scratch.available(), ); - let base2k_in: usize = res.base2k().into(); - let base2k_out: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().as_usize(); + let base2k_key: usize = key.base2k().as_usize(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + + self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2) + } else { + self.glwe_keyswitch_internal(res_dft, res, key, scratch_1) + }; + for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( - base2k_in, - &mut res.data, + base2k_res, + res.data_mut(), i, - base2k_out, + base2k_key, &res_big, i, scratch_1, @@ -216,14 +256,7 @@ where { let cols: usize = (a_infos.rank() + 1).into(); let a_size: usize = a_infos.size(); - - let a_conv = if a_infos.base2k() == key_infos.base2k() { - 0 - } else { - VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes() - }; - - self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv + self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) } fn glwe_keyswitch_internal( @@ -241,36 +274,14 @@ where { let a: &GLWE<&[u8]> = &a.to_ref(); let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - - let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = key.base2k().into(); + assert_eq!(a.base2k(), key.base2k()); let cols: usize = (a.rank() + 1).into(); - let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - + let a_size: usize = a.size(); let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); - - if base2k_in == base2k_out { - for col_i in 0..cols - 1 { - self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols - 1 { - self.vec_znx_normalize( - base2k_out, - &mut a_conv, - 0, - base2k_in, - a.data(), - i + 1, - scratch_2, - ); - self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); - } + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); } - self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1); - let mut res_big: VecZnxBig = self.vec_znx_idft_apply_consume(res); self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); res_big diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index 334aee5..eecaaea 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -6,9 +6,14 @@ use poulpy_hal::{ use rug::Float; use crate::{ - GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWENoise, GLWENormalize, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, LWEFromGLWE, LWEToGLWESwitchingKeyEncryptSk, SIGMA, ScratchTakeCore, layouts::{ - Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey, GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision, prepared::GLWESecretPrepared - } + GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWENoise, GLWENormalize, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, + LWEFromGLWE, LWEToGLWESwitchingKeyEncryptSk, SIGMA, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey, + GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, + LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision, + prepared::GLWESecretPrepared, + }, }; pub fn test_glwe_base2k_conversion(module: &Module) @@ -70,15 +75,24 @@ where ); let mut data: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); - ct_in.data().decode_vec_float(ct_in.base2k().into(), 0, &mut data); + ct_in + .data() + .decode_vec_float(ct_in.base2k().into(), 0, &mut data); - ct_out.fill_uniform(ct_out.base2k().into(),&mut source_xa); + ct_out.fill_uniform(ct_out.base2k().into(), &mut source_xa); module.glwe_normalize(&mut ct_out, &ct_in, scratch.borrow()); - - let mut data_conv: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); - ct_out.data().decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv); - ct_out.assert_noise(module, &sk_prep, &pt_out, -(ct_out.k().as_u32() as f64) + SIGMA.log2() + 0.5); + let mut data_conv: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); + ct_out + .data() + .decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv); + + ct_out.assert_noise( + module, + &sk_prep, + &pt_out, + -(ct_out.k().as_u32() as f64) + SIGMA.log2() + 0.5, + ); } } } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index c4cb283..43b4729 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -92,7 +92,6 @@ where let k_pt: usize = 30; for rank in 1_usize..3 { - let n: usize = module.n(); let glwe_infos: GLWELayout = GLWELayout { diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index f43ec56..4abffc6 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -5,14 +5,15 @@ use poulpy_hal::{ }; use crate::{ - GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, + GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWENormalize, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, - GLWESwitchingKeyPreparedFactory, + GLWESwitchingKeyPreparedFactory, LWEInfos, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; #[allow(clippy::too_many_arguments)] @@ -24,44 +25,46 @@ where + GLWEKeyswitch + GLWESecretPreparedFactory + GLWESwitchingKeyPreparedFactory - + GLWENoise, + + GLWENoise + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_glwe: usize = 12; - let base2k_gglwe: usize = 8; - let k_in: usize = 45; - let dsize: usize = k_in.div_ceil(base2k_gglwe); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize+1 { - let k_ksk: usize = k_in + base2k_gglwe * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // better capture noise let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k_gglwe * dsize); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_glwe.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank_in.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_glwe.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank_out.into(), }; let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_gglwe.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), }; @@ -69,13 +72,14 @@ where let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk); let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k_glwe, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(pt_in.base2k().into(), &mut pt_in.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk) @@ -106,7 +110,7 @@ where glwe_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_in_prepared, &mut source_xa, &mut source_xe, @@ -119,20 +123,25 @@ where glwe_out.keyswitch(module, &glwe_in, &ksk_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k_gglwe * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank_in as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); - glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); + + glwe_out.assert_noise(module, &sk_out_prepared, &pt_out, max_noise + 0.5); } } } @@ -150,30 +159,31 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 45; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); - + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -186,7 +196,12 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform( + pt_want.base2k().into(), + &mut pt_want.data, + 0, + &mut source_xa, + ); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos) @@ -230,18 +245,21 @@ where glwe_out.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_out, - k_ksk, - ); + ) + .sqrt() + .log2(); glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); }