automorphism gglwe

This commit is contained in:
Jean-Philippe Bossuat
2025-10-16 10:52:55 +02:00
parent 3236e1be2c
commit bdd00b557f
6 changed files with 230 additions and 235 deletions

View File

@@ -16,14 +16,14 @@ use crate::{
};
impl GLWE<Vec<u8>> {
pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGLWEInfos,
M: GLWEKeyswitch<BE>,
{
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos)
module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
}
}
@@ -89,7 +89,7 @@ where
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
{
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
@@ -97,44 +97,44 @@ where
{
let in_size: usize = a_infos
.k()
.div_ceil(b_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize;
.div_ceil(key_infos.base2k())
.div_ceil(key_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size();
let ksk_size: usize = b_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size);
let ksk_size: usize = key_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size,
(b_infos.rank_in()).into(),
(b_infos.rank_out() + 1).into(),
(key_infos.rank_in()).into(),
(key_infos.rank_out() + 1).into(),
ksk_size,
) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size);
) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
if a_infos.base2k() == b_infos.base2k() {
if a_infos.base2k() == key_infos.base2k() {
res_dft + ((ai_dft + vmp) | normalize_big)
} else if b_infos.dsize() == 1 {
} else if key_infos.dsize() == 1 {
// In this case, we only need one column, temporary, that we can drop once a_dft is computed.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
} else {
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size);
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}
fn glwe_keyswitch<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
B: GLWESwitchingKeyPreparedToRef<BE>,
K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref();
let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
a.rank(),
@@ -181,14 +181,14 @@ where
})
}
fn glwe_keyswitch_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWESwitchingKeyPreparedToRef<BE>,
K: GLWESwitchingKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref();
let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref();
assert_eq!(
res.rank(),
@@ -243,7 +243,7 @@ pub(crate) fn keyswitch_internal<BE: Backend, M, DR, DA, DB>(
module: &M,
mut res: VecZnxDft<DR, BE>,
a: &GLWE<DA>,
b: &GLWESwitchingKeyPrepared<DB, BE>,
key: &GLWESwitchingKeyPrepared<DB, BE>,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
@@ -265,12 +265,12 @@ where
Scratch<BE>: ScratchTakeCore<BE>,
{
let base2k_in: usize = a.base2k().into();
let base2k_out: usize = b.base2k().into();
let base2k_out: usize = key.base2k().into();
let cols: usize = (a.rank() + 1).into();
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
let pmat: &VmpPMat<DB, BE> = &b.key.data;
let pmat: &VmpPMat<DB, BE> = &key.key.data;
if b.dsize() == 1 {
if key.dsize() == 1 {
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
if base2k_in == base2k_out {
@@ -295,7 +295,7 @@ where
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
} else {
let dsize: usize = b.dsize().into();
let dsize: usize = key.dsize().into();
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
ai_dft.data_mut().fill(0);