Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -3,17 +3,17 @@ use std::collections::HashMap;
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate,
VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
};
use crate::{
GLWEOperations, TakeGLWECt,
layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared},
layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared},
};
/// [GLWEPacker] enables only the fly GLWE packing
@@ -40,12 +40,15 @@ impl Accumulator {
/// #Arguments
///
/// * `module`: static backend FFT tables.
/// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
/// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
/// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
/// * `rank`: rank of the GLWE ciphertext.
pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
pub fn alloc<A>(infos: &A) -> Self
where
A: GLWEInfos,
{
Self {
data: GLWECiphertext::alloc(n, basek, k, rank),
data: GLWECiphertext::alloc(infos),
value: false,
control: false,
}
@@ -63,13 +66,13 @@ impl GLWEPacker {
/// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients
/// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts
/// can be packed.
/// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
/// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
/// * `rank`: rank of the GLWE ciphertext.
pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self {
pub fn new<A>(infos: &A, log_batch: usize) -> Self
where
A: GLWEInfos,
{
let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _;
(0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank)));
let log_n: usize = infos.n().log2();
(0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
Self {
accumulators,
log_batch,
@@ -87,18 +90,13 @@ impl GLWEPacker {
}
/// Number of scratch space bytes required to call [Self::add].
pub fn scratch_space<B: Backend>(
module: &Module<B>,
basek: usize,
ct_k: usize,
k_ksk: usize,
digits: usize,
rank: usize,
) -> usize
pub fn scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
OUT: GLWEInfos,
KEY: GGLWELayoutInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
pack_core_scratch_space(module, out_infos, key_infos)
}
pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
@@ -137,17 +135,19 @@ impl GLWEPacker {
+ VecZnxRshInplace<B>
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxRotate
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxBigAutomorphismInplace<B>,
+ VecZnxBigSubSmallNegateInplace<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
assert!(
self.counter < self.accumulators[0].data.n(),
(self.counter as u32) < self.accumulators[0].data.n(),
"Packing limit of {} reached",
self.accumulators[0].data.n() >> self.log_batch
self.accumulators[0].data.n().0 as usize >> self.log_batch
);
pack_core(
@@ -166,7 +166,7 @@ impl GLWEPacker {
where
Module<B>: VecZnxCopy,
{
assert!(self.counter == self.accumulators[0].data.n());
assert!(self.counter as u32 == self.accumulators[0].data.n());
// Copy result GLWE into res GLWE
res.copy(
module,
@@ -177,18 +177,13 @@ impl GLWEPacker {
}
}
fn pack_core_scratch_space<B: Backend>(
module: &Module<B>,
basek: usize,
ct_k: usize,
k_ksk: usize,
digits: usize,
rank: usize,
) -> usize
fn pack_core_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
OUT: GLWEInfos,
KEY: GGLWELayoutInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
combine_scratch_space(module, out_infos, key_infos)
}
fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
@@ -215,11 +210,13 @@ fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
+ VecZnxRshInplace<B>
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxRotate
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxBigAutomorphismInplace<B>,
+ VecZnxBigSubSmallNegateInplace<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
let log_n: usize = module.log_n();
@@ -271,20 +268,15 @@ fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
}
}
fn combine_scratch_space<B: Backend>(
module: &Module<B>,
basek: usize,
ct_k: usize,
k_ksk: usize,
digits: usize,
rank: usize,
) -> usize
fn combine_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
where
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
OUT: GLWEInfos,
KEY: GGLWELayoutInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWECiphertext::bytes_of(module.n(), basek, ct_k, rank)
GLWECiphertext::alloc_bytes(out_infos)
+ (GLWECiphertext::rsh_scratch_space(module.n())
| GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank))
| GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos))
}
/// [combine] merges two ciphertexts together.
@@ -312,19 +304,17 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
+ VecZnxRshInplace<B>
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxRotate
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxBigAutomorphismInplace<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
+ VecZnxBigSubSmallNegateInplace<B>
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeGLWECt,
{
let n: usize = acc.data.n();
let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _;
let log_n: usize = acc.data.n().log2();
let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
let basek: usize = a.basek();
let k: usize = a.k();
let rank: usize = a.rank();
let gal_el: i64 = if i == 0 {
-1
@@ -346,7 +336,7 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
// since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
if acc.value {
if let Some(b) = b {
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
// a = a * X^-t
a.rotate_inplace(module, -t, scratch_1);
@@ -365,7 +355,7 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
if let Some(key) = auto_keys.get(&gal_el) {
tmp_b.automorphism_inplace(module, key, scratch_1);
} else {
panic!("auto_key[{}] not found", gal_el);
panic!("auto_key[{gal_el}] not found");
}
// a = a * X^-t + b - phi(a * X^-t - b)
@@ -382,19 +372,19 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
if let Some(key) = auto_keys.get(&gal_el) {
a.automorphism_add_inplace(module, key, scratch);
} else {
panic!("auto_key[{}] not found", gal_el);
panic!("auto_key[{gal_el}] not found");
}
}
} else if let Some(b) = b {
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
tmp_b.rotate(module, 1 << (log_n - i - 1), b);
tmp_b.rsh(module, 1, scratch_1);
// a = (b* X^t - phi(b* X^t))
if let Some(key) = auto_keys.get(&gal_el) {
a.automorphism_sub_ba(module, &tmp_b, key, scratch_1);
a.automorphism_sub_negate(module, &tmp_b, key, scratch_1);
} else {
panic!("auto_key[{}] not found", gal_el);
panic!("auto_key[{gal_el}] not found");
}
acc.value = true;