mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization * updated method signatures to take layouts * fixed cross-base normalization fix #91 fix #93
This commit is contained in:
committed by
GitHub
parent
4da790ea6a
commit
37e13b965c
@@ -2,15 +2,19 @@ use std::collections::HashMap;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxRshInplace,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
|
||||
VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
layouts::{GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared},
|
||||
TakeGLWECt,
|
||||
layouts::{
|
||||
Base2K, GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos,
|
||||
prepared::GGLWEAutomorphismKeyPrepared,
|
||||
},
|
||||
operations::GLWEOperations,
|
||||
};
|
||||
|
||||
@@ -27,34 +31,38 @@ impl GLWECiphertext<Vec<u8>> {
|
||||
gal_els
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn trace_scratch_space<B: Backend>(
|
||||
pub fn trace_scratch_space<B: Backend, OUT, IN, KEY>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ksk_k: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
out_infos: &OUT,
|
||||
in_infos: &IN,
|
||||
key_infos: &KEY,
|
||||
) -> usize
|
||||
where
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
OUT: GLWEInfos,
|
||||
IN: GLWEInfos,
|
||||
KEY: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank)
|
||||
let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos);
|
||||
if in_infos.base2k() != key_infos.base2k() {
|
||||
let glwe_conv: usize = VecZnx::alloc_bytes(
|
||||
module.n(),
|
||||
(key_infos.rank_out() + 1).into(),
|
||||
out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize,
|
||||
) + module.vec_znx_normalize_tmp_bytes();
|
||||
return glwe_conv + trace;
|
||||
}
|
||||
|
||||
trace
|
||||
}
|
||||
|
||||
pub fn trace_inplace_scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
ksk_k: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize
|
||||
pub fn trace_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
|
||||
where
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
OUT: GLWEInfos,
|
||||
KEY: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank)
|
||||
Self::trace_scratch_space(module, out_infos, out_infos, key_infos)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,8 +87,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxCopy,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
+ VecZnxCopy
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
{
|
||||
self.copy(module, lhs);
|
||||
self.trace_inplace(module, start, end, auto_keys, scratch);
|
||||
@@ -104,23 +114,92 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxRshInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
{
|
||||
(start..end).for_each(|i| {
|
||||
self.rsh(module, 1, scratch);
|
||||
let basek_ksk: Base2K = auto_keys
|
||||
.get(auto_keys.keys().next().unwrap())
|
||||
.unwrap()
|
||||
.base2k();
|
||||
|
||||
let p: i64 = if i == 0 {
|
||||
-1
|
||||
} else {
|
||||
module.galois_element(1 << (i - 1))
|
||||
};
|
||||
|
||||
if let Some(key) = auto_keys.get(&p) {
|
||||
self.automorphism_add_inplace(module, key, scratch);
|
||||
} else {
|
||||
panic!("auto_keys[{}] is empty", p)
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.n(), module.n() as u32);
|
||||
assert!(start < end);
|
||||
assert!(end <= module.log_n());
|
||||
for key in auto_keys.values() {
|
||||
assert_eq!(key.n(), module.n() as u32);
|
||||
assert_eq!(key.base2k(), basek_ksk);
|
||||
assert_eq!(key.rank_in(), self.rank());
|
||||
assert_eq!(key.rank_out(), self.rank());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if self.base2k() != basek_ksk {
|
||||
let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
|
||||
n: module.n().into(),
|
||||
base2k: basek_ksk,
|
||||
k: self.k(),
|
||||
rank: self.rank(),
|
||||
});
|
||||
|
||||
for j in 0..(self.rank() + 1).into() {
|
||||
module.vec_znx_normalize(
|
||||
basek_ksk.into(),
|
||||
&mut self_conv.data,
|
||||
j,
|
||||
basek_ksk.into(),
|
||||
&self.data,
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
|
||||
for i in start..end {
|
||||
self_conv.rsh(module, 1, scratch_1);
|
||||
|
||||
let p: i64 = if i == 0 {
|
||||
-1
|
||||
} else {
|
||||
module.galois_element(1 << (i - 1))
|
||||
};
|
||||
|
||||
if let Some(key) = auto_keys.get(&p) {
|
||||
self_conv.automorphism_add_inplace(module, key, scratch_1);
|
||||
} else {
|
||||
panic!("auto_keys[{p}] is empty")
|
||||
}
|
||||
}
|
||||
|
||||
for j in 0..(self.rank() + 1).into() {
|
||||
module.vec_znx_normalize(
|
||||
self.base2k().into(),
|
||||
&mut self.data,
|
||||
j,
|
||||
basek_ksk.into(),
|
||||
&self_conv.data,
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for i in start..end {
|
||||
self.rsh(module, 1, scratch);
|
||||
|
||||
let p: i64 = if i == 0 {
|
||||
-1
|
||||
} else {
|
||||
module.galois_element(1 << (i - 1))
|
||||
};
|
||||
|
||||
if let Some(key) = auto_keys.get(&p) {
|
||||
self.automorphism_add_inplace(module, key, scratch);
|
||||
} else {
|
||||
panic!("auto_keys[{p}] is empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user