automorphism gglwe

This commit is contained in:
Jean-Philippe Bossuat
2025-10-16 10:52:55 +02:00
parent 3236e1be2c
commit bdd00b557f
6 changed files with 230 additions and 235 deletions

View File

@@ -1,198 +1,169 @@
use poulpy_hal::{
api::{
ScratchAvailable, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
api::VecZnxAutomorphism,
layouts::{Backend, DataMut, GaloisElement, Module, Scratch},
};
use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared};
use crate::{
ScratchTakeCore,
automorphism::glwe_ct::GLWEAutomorphism,
layouts::{
AutomorphismKey, AutomorphismKeyToMut, AutomorphismKeyToRef, GGLWEInfos, GLWE, GLWEInfos,
prepared::{
AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement, SetAutomorphismGaloisElement,
},
},
};
impl AutomorphismKey<Vec<u8>> {
pub fn automorphism_tmp_bytes<B: Backend, OUT, IN, KEY>(
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
key_infos: &KEY,
) -> usize
pub fn automorphism_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
OUT: GGLWEInfos,
IN: GGLWEInfos,
KEY: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
M: AutomorphismKeyAutomorphism<BE>,
{
GLWE::keyswitch_tmp_bytes(
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
key_infos,
)
}
pub fn automorphism_inplace_tmp_bytes<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
OUT: GGLWEInfos,
KEY: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos)
module.automorphism_key_automorphism_tmp_bytes(res_infos, a_infos, key_infos)
}
}
impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &AutomorphismKey<DataLhs>,
rhs: &AutomorphismKeyPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
pub fn automorphism<A, K, M, BE: Backend>(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
A: AutomorphismKeyToRef + GetAutomorphismGaloisElement,
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
M: AutomorphismKeyAutomorphism<BE>,
{
#[cfg(debug_assertions)]
{
use crate::layouts::LWEInfos;
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
assert!(
self.k() <= lhs.k(),
"output k={} cannot be greater than input k={}",
self.k(),
lhs.k()
)
}
let cols_out: usize = (rhs.rank_out() + 1).into();
let p: i64 = lhs.p();
let p_inv: i64 = module.galois_element_inv(p);
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i);
let lhs_ct: GLWE<&[u8]> = lhs.at(row_j, col_i);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
});
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
res_ct.keyswitch_inplace(module, &rhs.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
});
});
(self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_j| {
self.at_mut(row_i, col_j).data.zero();
});
});
self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
module.automorphism_key_automorphism(self, a, key, scratch);
}
pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &AutomorphismKeyPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
pub fn automorphism_inplace<K, M, BE: Backend>(&mut self, module: &M, key: &K, scratch: &mut Scratch<BE>)
where
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
M: AutomorphismKeyAutomorphism<BE>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let cols_out: usize = (rhs.rank_out() + 1).into();
let p: i64 = self.p();
let p_inv = module.galois_element_inv(p);
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
res_ct.keyswitch_inplace(module, &rhs.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
});
});
});
self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64);
module.automorphism_key_automorphism_inplace(self, key, scratch);
}
}
impl<BE: Backend> AutomorphismKeyAutomorphism<BE> for Module<BE> where
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism
{
}
pub trait AutomorphismKeyAutomorphism<BE: Backend>
where
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism,
{
fn automorphism_key_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
{
self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
fn automorphism_key_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: AutomorphismKeyToMut + SetAutomorphismGaloisElement,
A: AutomorphismKeyToRef + GetAutomorphismGaloisElement,
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
{
{
let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut();
let a: &AutomorphismKey<&[u8]> = &a.to_ref();
let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref();
assert!(
res.dnum().as_u32() <= a.dnum().as_u32(),
"res dnum: {} > a dnum: {}",
res.dnum(),
a.dnum()
);
assert_eq!(
res.dsize(),
a.dsize(),
"res dnum: {} != a dnum: {}",
res.dsize(),
a.dsize()
);
let cols_out: usize = (key.rank_out() + 1).into();
let p: i64 = a.p();
let p_inv: i64 = self.galois_element_inv(p);
for row in 0..res.dnum().as_usize() {
for col in 0..cols_out {
let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col);
let a_ct: GLWE<&[u8]> = a.at(row, col);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
for i in 0..cols_out {
self.vec_znx_automorphism(a.p(), res_tmp.data_mut(), i, &a_ct.data, i);
}
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| {
self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
});
}
}
}
res.set_p((a.p() * key.p()) % (self.cyclotomic_order() as i64));
}
fn automorphism_key_automorphism_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: AutomorphismKeyToMut + SetAutomorphismGaloisElement + GetAutomorphismGaloisElement,
K: AutomorphismKeyPreparedToRef<BE> + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
{
{
let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut();
let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref();
assert_eq!(
res.rank(),
key.rank(),
"key rank: {} != key rank: {}",
res.rank(),
key.rank()
);
let cols_out: usize = (key.rank_out() + 1).into();
let p: i64 = res.p();
let p_inv: i64 = self.galois_element_inv(p);
for row in 0..res.dnum().as_usize() {
for col in 0..cols_out {
let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
for i in 0..cols_out {
self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch);
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
for i in 0..cols_out {
self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}
}
}
}
res.set_p((res.p() * key.p()) % (self.cyclotomic_order() as i64));
}
}

View File

@@ -16,7 +16,7 @@ impl AutomorphismKey<Vec<u8>> {
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
@@ -28,7 +28,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
A: AutomorphismKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch);
}
@@ -37,7 +37,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
where
A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch);
}
@@ -49,7 +49,7 @@ impl GLWESwitchingKey<Vec<u8>> {
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
@@ -61,7 +61,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
A: GLWESwitchingKeyToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch);
}
@@ -70,7 +70,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
where
A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch_inplace(&mut self.key, a, scratch);
}
@@ -82,7 +82,7 @@ impl GGLWE<Vec<u8>> {
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
@@ -94,7 +94,7 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
A: GGLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch(self, a, b, scratch);
}
@@ -103,15 +103,15 @@ impl<DataSelf: DataMut> GGLWE<DataSelf> {
where
A: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeySwitch<BE>,
M: GGLWEKeyswitch<BE>,
{
module.gglwe_keyswitch_inplace(self, a, scratch);
}
}
impl<BE: Backend> GGLWEKeySwitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {}
impl<BE: Backend> GGLWEKeyswitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {}
pub trait GGLWEKeySwitch<BE: Backend>
pub trait GGLWEKeyswitch<BE: Backend>
where
Self: GLWEKeyswitch<BE>,
{

View File

@@ -22,7 +22,7 @@ impl GGSW<Vec<u8>> {
A: GGSWInfos,
K: GGLWEInfos,
T: GGLWEInfos,
M: GGSWKeySwitch<BE>,
M: GGSWKeyswitch<BE>,
{
module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos)
}
@@ -35,7 +35,7 @@ impl<D: DataMut> GGSW<D> {
K: GLWESwitchingKeyPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeySwitch<BE>,
M: GGSWKeyswitch<BE>,
{
module.ggsw_keyswitch(self, a, key, tsk, scratch);
}
@@ -45,13 +45,13 @@ impl<D: DataMut> GGSW<D> {
K: GLWESwitchingKeyPreparedToRef<BE>,
T: TensorKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeySwitch<BE>,
M: GGSWKeyswitch<BE>,
{
module.ggsw_keyswitch_inplace(self, key, tsk, scratch);
}
}
pub trait GGSWKeySwitch<BE: Backend>
pub trait GGSWKeyswitch<BE: Backend>
where
Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
{

View File

@@ -16,14 +16,14 @@ use crate::{
};
impl GLWE<Vec<u8>> {
pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGLWEInfos,
M: GLWEKeyswitch<BE>,
{
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos)
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
}
@@ -89,7 +89,7 @@ where
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
{
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
@@ -97,44 +97,44 @@ where
{
let in_size: usize = a_infos
.k()
.div_ceil(b_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize;
.div_ceil(key_infos.base2k())
.div_ceil(key_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size();
let ksk_size: usize = b_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size);
let ksk_size: usize = key_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size,
(b_infos.rank_in()).into(),
(b_infos.rank_out() + 1).into(),
(key_infos.rank_in()).into(),
(key_infos.rank_out() + 1).into(),
ksk_size,
) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size);
) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
if a_infos.base2k() == b_infos.base2k() {
if a_infos.base2k() == key_infos.base2k() {
res_dft + ((ai_dft + vmp) | normalize_big)
} else if b_infos.dsize() == 1 {
} else if key_infos.dsize() == 1 {
// In this case, we only need one column, temporary, that we can drop once a_dft is computed.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
} else {
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size);
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}
fn glwe_keyswitch<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
a.rank(),
@@ -181,14 +181,14 @@ where
})
}
fn glwe_keyswitch_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWESwitchingKeyPreparedToRef<BE>,
K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
res.rank(),
@@ -243,7 +243,7 @@ pub(crate) fn keyswitch_internal<BE: Backend, M, DR, DA, DB>(
module: &M,
mut res: VecZnxDft<DR, BE>,
a: &GLWE<DA>,
b: &GLWESwitchingKeyPrepared<DB, BE>,
key: &GLWESwitchingKeyPrepared<DB, BE>,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
@@ -265,12 +265,12 @@ where
Scratch<BE>: ScratchTakeCore<BE>,
{
let base2k_in: usize = a.base2k().into();
let base2k_out: usize = b.base2k().into();
let base2k_out: usize = key.base2k().into();
let cols: usize = (a.rank() + 1).into();
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
let pmat: &VmpPMat<DB, BE> = &b.key.data;
let pmat: &VmpPMat<DB, BE> = &key.key.data;
if b.dsize() == 1 {
if key.dsize() == 1 {
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
if base2k_in == base2k_out {
@@ -295,7 +295,7 @@ where
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
} else {
let dsize: usize = b.dsize().into();
let dsize: usize = key.dsize().into();
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
ai_dft.data_mut().fill(0);

View File

@@ -6,6 +6,7 @@ use poulpy_hal::{
use crate::layouts::{
Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut,
GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision,
prepared::{GetAutomorphismGaloisElement, SetAutomorphismGaloisElement},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -27,6 +28,18 @@ pub struct AutomorphismKey<D: Data> {
pub(crate) p: i64,
}
impl<D: DataMut> SetAutomorphismGaloisElement for AutomorphismKey<D> {
fn set_p(&mut self, p: i64) {
self.p = p
}
}
impl<D: DataRef> GetAutomorphismGaloisElement for AutomorphismKey<D> {
fn p(&self) -> i64 {
self.p
}
}
impl<D: Data> AutomorphismKey<D> {
pub fn p(&self) -> i64 {
self.p

View File

@@ -2,7 +2,7 @@ use std::{fmt::Display, marker::PhantomData, ptr::NonNull};
use rand_distr::num_traits::Zero;
use crate::GALOISGENERATOR;
use crate::{GALOISGENERATOR, api::ModuleN};
#[allow(clippy::missing_safety_doc)]
pub trait Backend: Sized {
@@ -75,36 +75,47 @@ impl<B: Backend> Module<B> {
pub fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
}
#[inline]
pub fn cyclotomic_order(&self) -> u64 {
pub trait CyclotomicOrder
where
Self: ModuleN,
{
fn cyclotomic_order(&self) -> i64 {
(self.n() << 1) as _
}
}
impl<BE: Backend> CyclotomicOrder for Module<BE> where Self: ModuleN {}
pub trait GaloisElement
where
Self: CyclotomicOrder,
{
// Returns GALOISGENERATOR^|generator| * sign(generator)
#[inline]
pub fn galois_element(&self, generator: i64) -> i64 {
fn galois_element(&self, generator: i64) -> i64 {
if generator == 0 {
return 1;
}
((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64)
* generator.signum()
let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1) as u64;
g_exp as i64 * generator.signum()
}
// Returns gen^-1
#[inline]
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
fn galois_element_inv(&self, gal_el: i64) -> i64 {
if gal_el == 0 {
panic!("cannot invert 0")
}
((mod_exp_u64(
gal_el.unsigned_abs(),
(self.cyclotomic_order() - 1) as usize,
) & (self.cyclotomic_order() - 1)) as i64)
* gal_el.signum()
let g_exp: u64 =
mod_exp_u64(GALOISGENERATOR, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
g_exp as i64 * gal_el.signum()
}
}
impl<BE: Backend> GaloisElement for Module<BE> where Self: CyclotomicOrder {}
impl<B: Backend> Drop for Module<B> {
fn drop(&mut self) {
unsafe { B::destroy(self.ptr) }