Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace,
VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft,
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
@@ -54,12 +54,12 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace,
+ VecZnxAutomorphismInplace<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
#[cfg(debug_assertions)]
@@ -72,7 +72,7 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
@@ -113,7 +113,7 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
});
});
@@ -138,17 +138,56 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace,
+ VecZnxAutomorphismInplace<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
unsafe {
let self_ptr: *mut GGLWEAutomorphismKey<DataSelf> = self as *mut GGLWEAutomorphismKey<DataSelf>;
self.automorphism(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let cols_out: usize = rhs.rank_out() + 1;
let p: i64 = self.p();
let p_inv = module.galois_element_inv(p);
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
res_ct.keyswitch_inplace(module, &rhs.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
});
});
self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64);
}
}

View File

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace,
VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace,
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
@@ -79,16 +79,16 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
{
#[cfg(debug_assertions)]
@@ -133,7 +133,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
)
};
self.automorphism_internal(module, lhs, auto_key, scratch);
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
});
self.expand_row(module, tensor_key, scratch);
}
@@ -149,49 +155,25 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
{
unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
}
}
fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GGSWCiphertext<DataLhs>,
auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
(0..self.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
.automorphism_inplace(module, auto_key, scratch);
});
self.expand_row(module, tensor_key, scratch);
}
}

View File

@@ -1,8 +1,9 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace,
VecZnxBigSubSmallBInplace, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, VecZnxBigSubSmallBInplace,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
};
@@ -54,16 +55,16 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace,
+ VecZnxAutomorphismInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
self.keyswitch(module, lhs, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
})
}
@@ -78,16 +79,16 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace,
+ VecZnxAutomorphismInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
self.keyswitch_inplace(module, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
})
}
@@ -103,8 +104,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>,
@@ -114,12 +115,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
}
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
@@ -134,17 +135,24 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.automorphism_add(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
}
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
pub fn automorphism_sub_ab<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
@@ -159,8 +167,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
@@ -171,12 +179,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
}
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
@@ -191,18 +199,25 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxBigSubSmallAInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
}
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &self.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
pub fn automorphism_sub_ba<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
@@ -217,8 +232,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
@@ -229,12 +244,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
}
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
@@ -249,17 +264,24 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
}
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &self.data, i);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
}