automorphism gglwe

This commit is contained in:
Jean-Philippe Bossuat
2025-10-16 10:52:55 +02:00
parent 3236e1be2c
commit bdd00b557f
6 changed files with 230 additions and 235 deletions

View File

@@ -1,198 +1,169 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::VecZnxAutomorphism,
ScratchAvailable, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, layouts::{Backend, DataMut, GaloisElement, Module, Scratch},
VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
}; };
use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared}; use crate::{
ScratchTakeCore,
automorphism::glwe_ct::GLWEAutomorphism,
layouts::{
AutomorphismKey, AutomorphismKeyToMut, AutomorphismKeyToRef, GGLWEInfos, GLWE, GLWEInfos,
prepared::{
AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement, SetAutomorphismGaloisElement,
},
},
};
impl AutomorphismKey<Vec<u8>> { impl AutomorphismKey<Vec<u8>> {
pub fn automorphism_tmp_bytes<B: Backend, OUT, IN, KEY>( pub fn automorphism_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
key_infos: &KEY,
) -> usize
where where
OUT: GGLWEInfos, R: GGLWEInfos,
IN: GGLWEInfos, A: GGLWEInfos,
KEY: GGLWEInfos, K: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, M: AutomorphismKeyAutomorphism<BE>,
{ {
GLWE::keyswitch_tmp_bytes( module.automorphism_key_automorphism_tmp_bytes(res_infos, a_infos, key_infos)
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
key_infos,
)
}
pub fn automorphism_inplace_tmp_bytes<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
OUT: GGLWEInfos,
KEY: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos)
} }
} }
impl<DataSelf: DataMut> AutomorphismKey<DataSelf> { impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>( pub fn automorphism<A, K, M, BE: Backend>(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, A: AutomorphismKeyToRef + GetAutomorphismGaloisElement,
lhs: &AutomorphismKey<DataLhs>, K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
rhs: &AutomorphismKeyPrepared<DataRhs, B>, Scratch<BE>: ScratchTakeCore<BE>,
scratch: &mut Scratch<B>, M: AutomorphismKeyAutomorphism<BE>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
{ {
#[cfg(debug_assertions)] module.automorphism_key_automorphism(self, a, key, scratch);
{ }
use crate::layouts::LWEInfos;
pub fn automorphism_inplace<K, M, BE: Backend>(&mut self, module: &M, key: &K, scratch: &mut Scratch<BE>)
where
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
M: AutomorphismKeyAutomorphism<BE>,
{
module.automorphism_key_automorphism_inplace(self, key, scratch);
}
}
impl<BE: Backend> AutomorphismKeyAutomorphism<BE> for Module<BE> where
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism
{
}
pub trait AutomorphismKeyAutomorphism<BE: Backend>
where
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism,
{
fn automorphism_key_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
{
self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
fn automorphism_key_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: AutomorphismKeyToMut + SetAutomorphismGaloisElement,
A: AutomorphismKeyToRef + GetAutomorphismGaloisElement,
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
{
{
let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut();
let a: &AutomorphismKey<&[u8]> = &a.to_ref();
let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref();
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
assert!( assert!(
self.k() <= lhs.k(), res.dnum().as_u32() <= a.dnum().as_u32(),
"output k={} cannot be greater than input k={}", "res dnum: {} > a dnum: {}",
self.k(), res.dnum(),
lhs.k() a.dnum()
) );
}
let cols_out: usize = (rhs.rank_out() + 1).into(); assert_eq!(
res.dsize(),
a.dsize(),
"res dnum: {} != a dnum: {}",
res.dsize(),
a.dsize()
);
let p: i64 = lhs.p(); let cols_out: usize = (key.rank_out() + 1).into();
let p_inv: i64 = module.galois_element_inv(p);
(0..self.rank_in().into()).for_each(|col_i| { let p: i64 = a.p();
(0..self.dnum().into()).for_each(|row_j| { let p_inv: i64 = self.galois_element_inv(p);
let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i);
let lhs_ct: GLWE<&[u8]> = lhs.at(row_j, col_i); for row in 0..res.dnum().as_usize() {
for col in 0..cols_out {
let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col);
let a_ct: GLWE<&[u8]> = a.at(row, col);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| { for i in 0..cols_out {
module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i); self.vec_znx_automorphism(a.p(), res_tmp.data_mut(), i, &a_ct.data, i);
}); }
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
res_ct.keyswitch_inplace(module, &rhs.key, scratch); self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| { (0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}); });
}); }
}); }
(self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_j| {
self.at_mut(row_i, col_j).data.zero();
});
});
self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
} }
pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>( res.set_p((a.p() * key.p()) % (self.cyclotomic_order() as i64));
&mut self,
module: &Module<B>,
rhs: &AutomorphismKeyPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
} }
let cols_out: usize = (rhs.rank_out() + 1).into(); fn automorphism_key_automorphism_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: AutomorphismKeyToMut + SetAutomorphismGaloisElement + GetAutomorphismGaloisElement,
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
{
{
let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut();
let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref();
let p: i64 = self.p(); assert_eq!(
let p_inv = module.galois_element_inv(p); res.rank(),
key.rank(),
"key rank: {} != key rank: {}",
res.rank(),
key.rank()
);
(0..self.rank_in().into()).for_each(|col_i| { let cols_out: usize = (key.rank_out() + 1).into();
(0..self.dnum().into()).for_each(|row_j| {
let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); let p: i64 = res.p();
let p_inv: i64 = self.galois_element_inv(p);
for row in 0..res.dnum().as_usize() {
for col in 0..cols_out {
let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| { for i in 0..cols_out {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}); }
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
res_ct.keyswitch_inplace(module, &rhs.key, scratch); self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| { for i in 0..cols_out {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}); }
}); }
}); }
}
self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64); res.set_p((res.p() * key.p()) % (self.cyclotomic_order() as i64));
} }
} }

View File

@@ -16,7 +16,7 @@ impl AutomorphismKey<Vec<u8>> {
R: GGLWEInfos, R: GGLWEInfos,
A: GGLWEInfos, A: GGLWEInfos,
K: GGLWEInfos, K: GGLWEInfos,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
} }
@@ -28,7 +28,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
A: AutomorphismKeyToRef, A: AutomorphismKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>, B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch); module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch);
} }
@@ -37,7 +37,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
where where
A: GLWESwitchingKeyPreparedToRef<BE>, A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch); module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch);
} }
@@ -49,7 +49,7 @@ impl GLWESwitchingKey<Vec<u8>> {
R: GGLWEInfos, R: GGLWEInfos,
A: GGLWEInfos, A: GGLWEInfos,
K: GGLWEInfos, K: GGLWEInfos,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
} }
@@ -61,7 +61,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
A: GLWESwitchingKeyToRef, A: GLWESwitchingKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>, B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch); module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch);
} }
@@ -70,7 +70,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
where where
A: GLWESwitchingKeyPreparedToRef<BE>, A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch_inplace(&mut self.key, a, scratch); module.gglwe_keyswitch_inplace(&mut self.key, a, scratch);
} }
@@ -82,7 +82,7 @@ impl GGLWE<Vec<u8>> {
R: GGLWEInfos, R: GGLWEInfos,
A: GGLWEInfos, A: GGLWEInfos,
K: GGLWEInfos, K: GGLWEInfos,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
} }
@@ -94,7 +94,7 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
A: GGLWEToRef, A: GGLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>, B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch(self, a, b, scratch); module.gglwe_keyswitch(self, a, b, scratch);
} }
@@ -103,15 +103,15 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
where where
A: GLWESwitchingKeyPreparedToRef<BE>, A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>, M: GGLWEKeyswitch<BE>,
{ {
module.gglwe_keyswitch_inplace(self, a, scratch); module.gglwe_keyswitch_inplace(self, a, scratch);
} }
} }
impl<BE: Backend> GGLWEKeySwitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {} impl<BE: Backend> GGLWEKeyswitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {}
pub trait GGLWEKeySwitch<BE: Backend> pub trait GGLWEKeyswitch<BE: Backend>
where where
Self: GLWEKeyswitch<BE>, Self: GLWEKeyswitch<BE>,
{ {

View File

@@ -22,7 +22,7 @@ impl GGSW<Vec<u8>> {
A: GGSWInfos, A: GGSWInfos,
K: GGLWEInfos, K: GGLWEInfos,
T: GGLWEInfos, T: GGLWEInfos,
M: GGSWKeySwitch<BE>, M: GGSWKeyswitch<BE>,
{ {
module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos)
} }
@@ -35,7 +35,7 @@ impl<D: DataMut> GGSW<D> {
K: GLWESwitchingKeyPreparedToRef<BE>, K: GLWESwitchingKeyPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>, T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeySwitch<BE>, M: GGSWKeyswitch<BE>,
{ {
module.ggsw_keyswitch(self, a, key, tsk, scratch); module.ggsw_keyswitch(self, a, key, tsk, scratch);
} }
@@ -45,13 +45,13 @@ impl<D: DataMut> GGSW<D> {
K: GLWESwitchingKeyPreparedToRef<BE>, K: GLWESwitchingKeyPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>, T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeySwitch<BE>, M: GGSWKeyswitch<BE>,
{ {
module.ggsw_keyswitch_inplace(self, key, tsk, scratch); module.ggsw_keyswitch_inplace(self, key, tsk, scratch);
} }
} }
pub trait GGSWKeySwitch<BE: Backend> pub trait GGSWKeyswitch<BE: Backend>
where where
Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>, Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
{ {

View File

@@ -16,14 +16,14 @@ use crate::{
}; };
impl GLWE<Vec<u8>> { impl GLWE<Vec<u8>> {
pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where where
R: GLWEInfos, R: GLWEInfos,
A: GLWEInfos, A: GLWEInfos,
B: GGLWEInfos, B: GGLWEInfos,
M: GLWEKeyswitch<BE>, M: GLWEKeyswitch<BE>,
{ {
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos) module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
} }
} }
@@ -89,7 +89,7 @@ where
+ VecZnxNormalize<BE> + VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes, + VecZnxNormalizeTmpBytes,
{ {
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where where
R: GLWEInfos, R: GLWEInfos,
A: GLWEInfos, A: GLWEInfos,
@@ -97,44 +97,44 @@ where
{ {
let in_size: usize = a_infos let in_size: usize = a_infos
.k() .k()
.div_ceil(b_infos.base2k()) .div_ceil(key_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize; .div_ceil(key_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size(); let out_size: usize = res_infos.size();
let ksk_size: usize = b_infos.size(); let ksk_size: usize = key_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size, out_size,
in_size, in_size,
in_size, in_size,
(b_infos.rank_in()).into(), (key_infos.rank_in()).into(),
(b_infos.rank_out() + 1).into(), (key_infos.rank_out() + 1).into(),
ksk_size, ksk_size,
) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
if a_infos.base2k() == b_infos.base2k() { if a_infos.base2k() == key_infos.base2k() {
res_dft + ((ai_dft + vmp) | normalize_big) res_dft + ((ai_dft + vmp) | normalize_big)
} else if b_infos.dsize() == 1 { } else if key_infos.dsize() == 1 {
// In this case, we only need one column, temporary, that we can drop once a_dft is computed. // In this case, we only need one column, temporary, that we can drop once a_dft is computed.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
} else { } else {
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size); let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
} }
} }
fn glwe_keyswitch<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>) fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where where
R: GLWEToMut, R: GLWEToMut,
A: GLWEToRef, A: GLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>, K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!( assert_eq!(
a.rank(), a.rank(),
@@ -181,14 +181,14 @@ where
}) })
} }
fn glwe_keyswitch_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>) fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where where
R: GLWEToMut, R: GLWEToMut,
A: GLWESwitchingKeyPreparedToRef<BE>, K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!( assert_eq!(
res.rank(), res.rank(),
@@ -243,7 +243,7 @@ pub(crate) fn keyswitch_internal<BE: Backend, M, DR, DA, DB>(
module: &M, module: &M,
mut res: VecZnxDft<DR, BE>, mut res: VecZnxDft<DR, BE>,
a: &GLWE<DA>, a: &GLWE<DA>,
b: &GLWESwitchingKeyPrepared<DB, BE>, key: &GLWESwitchingKeyPrepared<DB, BE>,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE> ) -> VecZnxBig<DR, BE>
where where
@@ -265,12 +265,12 @@ where
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let base2k_in: usize = a.base2k().into(); let base2k_in: usize = a.base2k().into();
let base2k_out: usize = b.base2k().into(); let base2k_out: usize = key.base2k().into();
let cols: usize = (a.rank() + 1).into(); let cols: usize = (a.rank() + 1).into();
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
let pmat: &VmpPMat<DB, BE> = &b.key.data; let pmat: &VmpPMat<DB, BE> = &key.key.data;
if b.dsize() == 1 { if key.dsize() == 1 {
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
if base2k_in == base2k_out { if base2k_in == base2k_out {
@@ -295,7 +295,7 @@ where
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
} else { } else {
let dsize: usize = b.dsize().into(); let dsize: usize = key.dsize().into();
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
ai_dft.data_mut().fill(0); ai_dft.data_mut().fill(0);

View File

@@ -6,6 +6,7 @@ use poulpy_hal::{
use crate::layouts::{ use crate::layouts::{
Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut,
GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision,
prepared::{GetAutomorphismGaloisElement, SetAutomorphismGaloisElement},
}; };
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -27,6 +28,18 @@ pub struct AutomorphismKey<D: Data> {
pub(crate) p: i64, pub(crate) p: i64,
} }
impl<D: DataMut> SetAutomorphismGaloisElement for AutomorphismKey<D> {
fn set_p(&mut self, p: i64) {
self.p = p
}
}
impl<D: DataRef> GetAutomorphismGaloisElement for AutomorphismKey<D> {
fn p(&self) -> i64 {
self.p
}
}
impl<D: Data> AutomorphismKey<D> { impl<D: Data> AutomorphismKey<D> {
pub fn p(&self) -> i64 { pub fn p(&self) -> i64 {
self.p self.p

View File

@@ -2,7 +2,7 @@ use std::{fmt::Display, marker::PhantomData, ptr::NonNull};
use rand_distr::num_traits::Zero; use rand_distr::num_traits::Zero;
use crate::GALOISGENERATOR; use crate::{GALOISGENERATOR, api::ModuleN};
#[allow(clippy::missing_safety_doc)] #[allow(clippy::missing_safety_doc)]
pub trait Backend: Sized { pub trait Backend: Sized {
@@ -75,36 +75,47 @@ impl<B: Backend> Module<B> {
pub fn log_n(&self) -> usize { pub fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _ (usize::BITS - (self.n() - 1).leading_zeros()) as _
} }
#[inline]
pub fn cyclotomic_order(&self) -> u64 {
(self.n() << 1) as _
} }
pub trait CyclotomicOrder
where
Self: ModuleN,
{
fn cyclotomic_order(&self) -> i64 {
(self.n() << 1) as _
}
}
impl<BE: Backend> CyclotomicOrder for Module<BE> where Self: ModuleN {}
pub trait GaloisElement
where
Self: CyclotomicOrder,
{
// Returns GALOISGENERATOR^|generator| * sign(generator) // Returns GALOISGENERATOR^|generator| * sign(generator)
#[inline] fn galois_element(&self, generator: i64) -> i64 {
pub fn galois_element(&self, generator: i64) -> i64 {
if generator == 0 { if generator == 0 {
return 1; return 1;
} }
((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64)
* generator.signum() let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1) as u64;
g_exp as i64 * generator.signum()
} }
// Returns gen^-1 // Returns gen^-1
#[inline] fn galois_element_inv(&self, gal_el: i64) -> i64 {
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
if gal_el == 0 { if gal_el == 0 {
panic!("cannot invert 0") panic!("cannot invert 0")
} }
((mod_exp_u64(
gal_el.unsigned_abs(), let g_exp: u64 =
(self.cyclotomic_order() - 1) as usize, mod_exp_u64(GALOISGENERATOR, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
) & (self.cyclotomic_order() - 1)) as i64) g_exp as i64 * gal_el.signum()
* gal_el.signum()
} }
} }
impl<BE: Backend> GaloisElement for Module<BE> where Self: CyclotomicOrder {}
impl<B: Backend> Drop for Module<B> { impl<B: Backend> Drop for Module<B> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { B::destroy(self.ptr) } unsafe { B::destroy(self.ptr) }