glwe trace

This commit is contained in:
Pro7ech
2025-10-16 17:54:12 +02:00
parent d27d43759a
commit 827d257d0a

View File

@@ -1,170 +1,188 @@
use std::collections::HashMap; use std::collections::HashMap;
use poulpy_hal::{ use poulpy_hal::{
api::{ api::ModuleLogN,
ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx},
VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx},
}; };
use crate::layouts::{Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWEInfos, prepared::AutomorphismKeyPrepared}; use crate::{
GLWEAutomorphism, GLWECopy, GLWEShift, ScratchTakeCore,
layouts::{
Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos,
prepared::{AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement},
},
};
impl GLWE<Vec<u8>> { impl GLWE<Vec<u8>> {
pub fn trace_galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> { pub fn trace_galois_elements<M, BE: Backend>(module: &M) -> Vec<i64>
let mut gal_els: Vec<i64> = Vec::new(); where
(0..module.log_n()).for_each(|i| { M: GLWETrace<BE>,
if i == 0 { {
gal_els.push(-1); module.glwe_trace_galois_elements()
} else {
gal_els.push(module.galois_element(1 << (i - 1)));
}
});
gal_els
} }
pub fn trace_tmp_bytes<B: Backend, OUT, IN, KEY>(module: &Module<B>, out_infos: &OUT, in_infos: &IN, key_infos: &KEY) -> usize pub fn trace_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where where
OUT: GLWEInfos, R: GLWEInfos,
IN: GLWEInfos, A: GLWEInfos,
KEY: GGLWEInfos, K: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, M: GLWETrace<BE>,
{ {
let trace: usize = Self::automorphism_inplace_tmp_bytes(module, out_infos, key_infos); module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos)
if in_infos.base2k() != key_infos.base2k() { }
}
impl<D: DataMut> GLWE<D> {
pub fn trace<A, K, M, BE: Backend>(
&mut self,
module: &M,
start: usize,
end: usize,
a: &A,
keys: &HashMap<i64, K>,
scratch: &mut Scratch<BE>,
) where
A: GLWEToRef,
K: AutomorphismKeyPreparedToRef<BE> + GGLWEInfos + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
M: GLWETrace<BE>,
{
module.glwe_trace(self, start, end, a, keys, scratch);
}
pub fn trace_inplace<K, M, BE: Backend>(
&mut self,
module: &M,
start: usize,
end: usize,
keys: &HashMap<i64, K>,
scratch: &mut Scratch<BE>,
) where
K: AutomorphismKeyPreparedToRef<BE> + GGLWEInfos + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
M: GLWETrace<BE>,
{
module.glwe_trace_inplace(self, start, end, keys, scratch);
}
}
impl<BE: Backend> GLWETrace<BE> for Module<BE> where
Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy
{
}
pub trait GLWETrace<BE: Backend>
where
Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy,
{
fn glwe_trace_galois_elements(&self) -> Vec<i64> {
(0..self.log_n())
.map(|i| {
if i == 0 {
-1
} else {
self.galois_element(1 << (i - 1))
}
})
.collect()
}
fn glwe_trace_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
K: GGLWEInfos,
{
let trace: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos);
if a_infos.base2k() != key_infos.base2k() {
let glwe_conv: usize = VecZnx::bytes_of( let glwe_conv: usize = VecZnx::bytes_of(
module.n(), self.n(),
(key_infos.rank_out() + 1).into(), (key_infos.rank_out() + 1).into(),
out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize, res_infos.k().min(a_infos.k()).div_ceil(key_infos.base2k()) as usize,
) + module.vec_znx_normalize_tmp_bytes(); ) + self.vec_znx_normalize_tmp_bytes();
return glwe_conv + trace; return glwe_conv + trace;
} }
trace trace
} }
pub fn trace_inplace_tmp_bytes<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize fn glwe_trace<R, A, K>(&self, res: &mut R, start: usize, end: usize, a: &A, keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
where where
OUT: GLWEInfos, R: GLWEToMut,
KEY: GGLWEInfos, A: GLWEToRef,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, K: AutomorphismKeyPreparedToRef<BE> + GGLWEInfos + GetAutomorphismGaloisElement,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
Self::trace_tmp_bytes(module, out_infos, out_infos, key_infos) self.glwe_copy(res, a);
} self.glwe_trace_inplace(res, start, end, keys, scratch);
}
impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn trace<DataLhs: DataRef, DataAK: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
start: usize,
end: usize,
lhs: &GLWE<DataLhs>,
auto_keys: &HashMap<i64, AutomorphismKeyPrepared<DataAK, B>>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>
+ VecZnxCopy
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{
self.copy(module, lhs);
self.trace_inplace(module, start, end, auto_keys, scratch);
} }
pub fn trace_inplace<DataAK: DataRef, B: Backend>( fn glwe_trace_inplace<R, K>(&self, res: &mut R, start: usize, end: usize, keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, R: GLWEToMut,
start: usize, K: AutomorphismKeyPreparedToRef<BE> + GGLWEInfos + GetAutomorphismGaloisElement,
end: usize, Scratch<BE>: ScratchTakeCore<BE>,
auto_keys: &HashMap<i64, AutomorphismKeyPrepared<DataAK, B>>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{ {
let basek_ksk: Base2K = auto_keys let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
.get(auto_keys.keys().next().unwrap())
.unwrap() let basek_ksk: Base2K = keys.get(keys.keys().next().unwrap()).unwrap().base2k();
.base2k();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.n(), module.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert!(start < end); assert!(start < end);
assert!(end <= module.log_n()); assert!(end <= self.log_n());
for key in auto_keys.values() { for key in keys.values() {
assert_eq!(key.n(), module.n() as u32); assert_eq!(key.n(), self.n() as u32);
assert_eq!(key.base2k(), basek_ksk); assert_eq!(key.base2k(), basek_ksk);
assert_eq!(key.rank_in(), self.rank()); assert_eq!(key.rank_in(), res.rank());
assert_eq!(key.rank_out(), self.rank()); assert_eq!(key.rank_out(), res.rank());
} }
} }
if self.base2k() != basek_ksk { if res.base2k() != basek_ksk {
let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWELayout { let (mut self_conv, scratch_1) = scratch.take_glwe_ct(
n: module.n().into(), self,
&GLWELayout {
n: self.n().into(),
base2k: basek_ksk, base2k: basek_ksk,
k: self.k(), k: res.k(),
rank: self.rank(), rank: res.rank(),
}); },
);
for j in 0..(self.rank() + 1).into() { for j in 0..(res.rank() + 1).into() {
module.vec_znx_normalize( self.vec_znx_normalize(
basek_ksk.into(), basek_ksk.into(),
&mut self_conv.data, &mut self_conv.data,
j, j,
basek_ksk.into(), basek_ksk.into(),
&self.data, res.data(),
j, j,
scratch_1, scratch_1,
); );
} }
for i in start..end { for i in start..end {
self_conv.rsh(module, 1, scratch_1); self.glwe_rsh(1, &mut self_conv, scratch_1);
let p: i64 = if i == 0 { let p: i64 = if i == 0 {
-1 -1
} else { } else {
module.galois_element(1 << (i - 1)) self.galois_element(1 << (i - 1))
}; };
if let Some(key) = auto_keys.get(&p) { if let Some(key) = keys.get(&p) {
self_conv.automorphism_add_inplace(module, key, scratch_1); self.glwe_automorphism_add_inplace(&mut self_conv, key, scratch_1);
} else { } else {
panic!("auto_keys[{p}] is empty") panic!("keys[{p}] is empty")
} }
} }
for j in 0..(self.rank() + 1).into() { for j in 0..(res.rank() + 1).into() {
module.vec_znx_normalize( self.vec_znx_normalize(
self.base2k().into(), res.base2k().into(),
&mut self.data, res.data_mut(),
j, j,
basek_ksk.into(), basek_ksk.into(),
&self_conv.data, &self_conv.data,
@@ -174,18 +192,18 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
} }
} else { } else {
for i in start..end { for i in start..end {
self.rsh(module, 1, scratch); self.glwe_rsh(1, res, scratch);
let p: i64 = if i == 0 { let p: i64 = if i == 0 {
-1 -1
} else { } else {
module.galois_element(1 << (i - 1)) self.galois_element(1 << (i - 1))
}; };
if let Some(key) = auto_keys.get(&p) { if let Some(key) = keys.get(&p) {
self.automorphism_add_inplace(module, key, scratch); self.glwe_automorphism_add_inplace(res, key, scratch);
} else { } else {
panic!("auto_keys[{p}] is empty") panic!("keys[{p}] is empty")
} }
} }
} }