keyswitch tests

This commit is contained in:
Pro7ech
2025-10-20 15:32:52 +02:00
parent 0c894c19db
commit 252eda36fe
60 changed files with 918 additions and 945 deletions

View File

@@ -3,11 +3,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch};
use crate::{
ScratchTakeCore,
keyswitching::glwe_ct::GLWEKeyswitch,
layouts::{
AutomorphismKey, AutomorphismKeyToRef, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWESwitchingKey,
GLWESwitchingKeyToRef,
prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedToRef},
},
layouts::{AutomorphismKey, GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWESwitchingKey},
};
impl AutomorphismKey<Vec<u8>> {
@@ -25,21 +21,21 @@ impl AutomorphismKey<Vec<u8>> {
impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: AutomorphismKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEToRef + GGLWEToRef,
B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch);
module.gglwe_keyswitch(self, a, b, scratch);
}
pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch);
module.gglwe_keyswitch_inplace(self, a, scratch);
}
}
@@ -58,21 +54,21 @@ impl GLWESwitchingKey<Vec<u8>> {
impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: GLWESwitchingKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEToRef,
B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch);
module.gglwe_keyswitch(self, a, b, scratch);
}
pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch_inplace(&mut self.key, a, scratch);
module.gglwe_keyswitch_inplace(self, a, scratch);
}
}
@@ -92,7 +88,7 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: GGLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
@@ -101,7 +97,7 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>,
{
@@ -128,12 +124,11 @@ where
where
R: GGLWEToMut,
A: GGLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref();
assert_eq!(
res.rank_in(),
@@ -180,11 +175,10 @@ where
fn gglwe_keyswitch_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut,
A: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref();
assert_eq!(
res.rank_out(),

View File

@@ -3,10 +3,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx};
use crate::{
GGSWExpandRows, ScratchTakeCore,
keyswitching::glwe_ct::GLWEKeyswitch,
layouts::{
GGLWEInfos, GGSW, GGSWInfos, GGSWToMut, GGSWToRef,
prepared::{GLWESwitchingKeyPreparedToRef, TensorKeyPreparedToRef},
},
layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::TensorKeyPreparedToRef},
};
impl GGSW<Vec<u8>> {
@@ -32,7 +29,7 @@ impl<D: DataMut> GGSW<D> {
pub fn keyswitch<M, A, K, T, BE: Backend>(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where
A: GGSWToRef,
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeyswitch<BE>,
@@ -42,7 +39,7 @@ impl<D: DataMut> GGSW<D> {
pub fn keyswitch_inplace<M, K, T, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeyswitch<BE>,
@@ -93,14 +90,15 @@ where
where
R: GGSWToMut,
A: GGSWToRef,
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref();
assert_eq!(res.ggsw_layout(), a.ggsw_layout());
assert!(res.dnum() <= a.dnum());
assert_eq!(res.dsize(), a.dsize());
for row in 0..a.dnum().into() {
// Key-switch column 0, i.e.
@@ -114,7 +112,7 @@ where
fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{

View File

@@ -4,15 +4,12 @@ use poulpy_hal::{
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};
use crate::{
ScratchTakeCore,
layouts::{
GGLWEInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedToRef},
},
layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos},
};
impl GLWE<Vec<u8>> {
@@ -31,7 +28,7 @@ impl<D: DataMut> GLWE<D> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: GLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
B: GGLWEPreparedToRef<BE>,
M: GLWEKeyswitch<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
@@ -40,7 +37,7 @@ impl<D: DataMut> GLWE<D> {
pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GLWESwitchingKeyPreparedToRef<BE>,
A: GGLWEPreparedToRef<BE>,
M: GLWEKeyswitch<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
@@ -129,12 +126,12 @@ where
where
R: GLWEToMut,
A: GLWEToRef,
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
let b: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
a.rank(),
@@ -184,11 +181,11 @@ where
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GLWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
res.rank(),
@@ -239,17 +236,17 @@ impl GLWE<Vec<u8>> {}
impl<DataSelf: DataMut> GLWE<DataSelf> {}
pub(crate) fn keyswitch_internal<BE: Backend, M, DR, DA, DB>(
pub(crate) fn keyswitch_internal<BE: Backend, M, DR, A, K>(
module: &M,
mut res: VecZnxDft<DR, BE>,
a: &GLWE<DA>,
key: &GLWESwitchingKeyPrepared<DB, BE>,
a: &A,
key: &K,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
DR: DataMut,
DA: DataRef,
DB: DataRef,
A: GLWEToRef,
K: GGLWEPreparedToRef<BE>,
M: ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
@@ -264,11 +261,14 @@ where
+ VecZnxNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let a: &GLWE<&[u8]> = &a.to_ref();
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
let base2k_in: usize = a.base2k().into();
let base2k_out: usize = key.base2k().into();
let cols: usize = (a.rank() + 1).into();
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
let pmat: &VmpPMat<DB, BE> = &key.key.data;
let pmat: &VmpPMat<&[u8], BE> = &key.data;
if key.dsize() == 1 {
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());

View File

@@ -6,10 +6,7 @@ use poulpy_hal::{
use crate::{
LWESampleExtract, ScratchTakeCore,
keyswitching::glwe_ct::GLWEKeyswitch,
layouts::{
GGLWEInfos, GLWE, GLWELayout, LWE, LWEInfos, LWEToMut, LWEToRef, Rank, TorusPrecision,
prepared::{LWESwitchingKeyPrepared, LWESwitchingKeyPreparedToRef},
},
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWELayout, LWE, LWEInfos, LWEToMut, LWEToRef, Rank, TorusPrecision},
};
impl LWE<Vec<u8>> {
@@ -28,7 +25,7 @@ impl<D: DataMut> LWE<D> {
pub fn keyswitch<M, A, K, BE: Backend>(&mut self, module: &M, a: &A, ksk: &K, scratch: &mut Scratch<BE>)
where
A: LWEToRef,
K: LWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
M: LWEKeySwitch<BE>,
{
@@ -36,7 +33,7 @@ impl<D: DataMut> LWE<D> {
}
}
impl<BE: Backend> LWEKeySwitch<BE> for Module<BE> where Self: LWEKeySwitch<BE> {}
impl<BE: Backend> LWEKeySwitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> + LWESampleExtract {}
pub trait LWEKeySwitch<BE: Backend>
where
@@ -75,12 +72,11 @@ where
where
R: LWEToMut,
A: LWEToRef,
K: LWESwitchingKeyPreparedToRef<BE>,
K: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
let a: &LWE<&[u8]> = &a.to_ref();
let ksk: &LWESwitchingKeyPrepared<&[u8], BE> = &ksk.to_ref();
assert!(res.n().as_usize() <= self.n());
assert!(a.n().as_usize() <= self.n());
@@ -120,7 +116,7 @@ where
glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
}
self.glwe_keyswitch(&mut glwe_out, &glwe_in, &ksk.0, scratch_1);
self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_1);
self.lwe_sample_extract(res, &glwe_out);
}
}