keyswitch tests

This commit is contained in:
Pro7ech
2025-10-20 15:32:52 +02:00
parent 0c894c19db
commit 252eda36fe
60 changed files with 918 additions and 945 deletions

View File

@@ -8,10 +8,9 @@ use poulpy_hal::{
};
use crate::{
GetDistribution, ScratchTakeCore,
encryption::gglwe_ksk::GLWESwitchingKeyEncryptSk,
GGLWEEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{
GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, LWEInfos, Rank, TensorKey, TensorKeyToMut,
GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, Rank, TensorKey, TensorKeyToMut,
prepared::{GLWESecretPrepare, GLWESecretPrepared, GLWESecretPreparedAlloc},
},
};
@@ -63,7 +62,7 @@ pub trait TensorKeyEncryptSk<BE: Backend> {
impl<BE: Backend> TensorKeyEncryptSk<BE> for Module<BE>
where
Self: ModuleN
+ GLWESwitchingKeyEncryptSk<BE>
+ GGLWEEncryptSk<BE>
+ VecZnxDftBytesOf
+ VecZnxBigBytesOf
+ GLWESecretPreparedAlloc<BE>
@@ -83,7 +82,7 @@ where
+ self.bytes_of_vec_znx_big(1, 1)
+ self.bytes_of_vec_znx_dft(1, 1)
+ GLWESecret::bytes_of(self.n().into(), Rank(1))
+ GLWESwitchingKey::encrypt_sk_tmp_bytes(self, infos)
+ GGLWE::encrypt_sk_tmp_bytes(self, infos)
}
fn tensor_key_encrypt_sk<R, S>(
@@ -102,8 +101,8 @@ where
// let n: RingDegree = sk.n();
let rank: Rank = res.rank_out();
let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, rank);
sk_dft_prep.prepare(self, sk);
let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank);
sk_prepared.prepare(self, sk);
let sk: &GLWESecret<&[u8]> = &sk.to_ref();
@@ -122,7 +121,7 @@ where
(0..rank.into()).for_each(|i| {
(i..rank.into()).for_each(|j| {
self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i);
self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
self.vec_znx_big_normalize(
@@ -135,8 +134,14 @@ where
scratch_5,
);
res.at_mut(i, j)
.encrypt_sk(self, &sk_ij, sk, source_xa, source_xe, scratch_5);
res.at_mut(i, j).encrypt_sk(
self,
&sk_ij.data,
&sk_prepared,
source_xa,
source_xe,
scratch_5,
);
});
})
}