Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -2,15 +2,19 @@ use std::collections::HashMap;
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxRshInplace,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx},
};
use crate::{
layouts::{GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared},
TakeGLWECt,
layouts::{
Base2K, GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos,
prepared::GGLWEAutomorphismKeyPrepared,
},
operations::GLWEOperations,
};
@@ -27,34 +31,38 @@ impl GLWECiphertext<Vec<u8>> {
gal_els
}
#[allow(clippy::too_many_arguments)]
pub fn trace_scratch_space<B: Backend>(
pub fn trace_scratch_space<B: Backend, OUT, IN, KEY>(
module: &Module<B>,
basek: usize,
out_k: usize,
in_k: usize,
ksk_k: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
in_infos: &IN,
key_infos: &KEY,
) -> usize
where
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
OUT: GLWEInfos,
IN: GLWEInfos,
KEY: GGLWELayoutInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank)
let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos);
if in_infos.base2k() != key_infos.base2k() {
let glwe_conv: usize = VecZnx::alloc_bytes(
module.n(),
(key_infos.rank_out() + 1).into(),
out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize,
) + module.vec_znx_normalize_tmp_bytes();
return glwe_conv + trace;
}
trace
}
pub fn trace_inplace_scratch_space<B: Backend>(
module: &Module<B>,
basek: usize,
out_k: usize,
ksk_k: usize,
digits: usize,
rank: usize,
) -> usize
pub fn trace_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
OUT: GLWEInfos,
KEY: GGLWELayoutInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank)
Self::trace_scratch_space(module, out_infos, out_infos, key_infos)
}
}
@@ -79,8 +87,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>
+ VecZnxCopy,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxCopy
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
self.copy(module, lhs);
self.trace_inplace(module, start, end, auto_keys, scratch);
@@ -104,23 +114,92 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxRshInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
(start..end).for_each(|i| {
self.rsh(module, 1, scratch);
let basek_ksk: Base2K = auto_keys
.get(auto_keys.keys().next().unwrap())
.unwrap()
.base2k();
let p: i64 = if i == 0 {
-1
} else {
module.galois_element(1 << (i - 1))
};
if let Some(key) = auto_keys.get(&p) {
self.automorphism_add_inplace(module, key, scratch);
} else {
panic!("auto_keys[{}] is empty", p)
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n() as u32);
assert!(start < end);
assert!(end <= module.log_n());
for key in auto_keys.values() {
assert_eq!(key.n(), module.n() as u32);
assert_eq!(key.base2k(), basek_ksk);
assert_eq!(key.rank_in(), self.rank());
assert_eq!(key.rank_out(), self.rank());
}
});
}
if self.base2k() != basek_ksk {
let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
n: module.n().into(),
base2k: basek_ksk,
k: self.k(),
rank: self.rank(),
});
for j in 0..(self.rank() + 1).into() {
module.vec_znx_normalize(
basek_ksk.into(),
&mut self_conv.data,
j,
basek_ksk.into(),
&self.data,
j,
scratch_1,
);
}
for i in start..end {
self_conv.rsh(module, 1, scratch_1);
let p: i64 = if i == 0 {
-1
} else {
module.galois_element(1 << (i - 1))
};
if let Some(key) = auto_keys.get(&p) {
self_conv.automorphism_add_inplace(module, key, scratch_1);
} else {
panic!("auto_keys[{p}] is empty")
}
}
for j in 0..(self.rank() + 1).into() {
module.vec_znx_normalize(
self.base2k().into(),
&mut self.data,
j,
basek_ksk.into(),
&self_conv.data,
j,
scratch_1,
);
}
} else {
for i in start..end {
self.rsh(module, 1, scratch);
let p: i64 = if i == 0 {
-1
} else {
module.galois_element(1 << (i - 1))
};
if let Some(key) = auto_keys.get(&p) {
self.automorphism_add_inplace(module, key, scratch);
} else {
panic!("auto_keys[{p}] is empty")
}
}
}
}
}