mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization * updated method signatures to take layouts * fixed cross-base normalization fix #91 fix #93
This commit is contained in:
committed by
GitHub
parent
4da790ea6a
commit
37e13b965c
@@ -3,17 +3,17 @@ use std::collections::HashMap;
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
|
||||
VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
GLWEOperations, TakeGLWECt,
|
||||
layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared},
|
||||
layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared},
|
||||
};
|
||||
|
||||
/// [GLWEPacker] enables only the fly GLWE packing
|
||||
@@ -40,12 +40,15 @@ impl Accumulator {
|
||||
/// #Arguments
|
||||
///
|
||||
/// * `module`: static backend FFT tables.
|
||||
/// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
|
||||
/// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
|
||||
/// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
|
||||
/// * `rank`: rank of the GLWE ciphertext.
|
||||
pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
|
||||
pub fn alloc<A>(infos: &A) -> Self
|
||||
where
|
||||
A: GLWEInfos,
|
||||
{
|
||||
Self {
|
||||
data: GLWECiphertext::alloc(n, basek, k, rank),
|
||||
data: GLWECiphertext::alloc(infos),
|
||||
value: false,
|
||||
control: false,
|
||||
}
|
||||
@@ -63,13 +66,13 @@ impl GLWEPacker {
|
||||
/// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients
|
||||
/// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts
|
||||
/// can be packed.
|
||||
/// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
|
||||
/// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
|
||||
/// * `rank`: rank of the GLWE ciphertext.
|
||||
pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self {
|
||||
pub fn new<A>(infos: &A, log_batch: usize) -> Self
|
||||
where
|
||||
A: GLWEInfos,
|
||||
{
|
||||
let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
|
||||
let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _;
|
||||
(0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank)));
|
||||
let log_n: usize = infos.n().log2();
|
||||
(0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
|
||||
Self {
|
||||
accumulators,
|
||||
log_batch,
|
||||
@@ -87,18 +90,13 @@ impl GLWEPacker {
|
||||
}
|
||||
|
||||
/// Number of scratch space bytes required to call [Self::add].
|
||||
pub fn scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
ct_k: usize,
|
||||
k_ksk: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize
|
||||
pub fn scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
|
||||
where
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
OUT: GLWEInfos,
|
||||
KEY: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
|
||||
pack_core_scratch_space(module, out_infos, key_infos)
|
||||
}
|
||||
|
||||
pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
|
||||
@@ -137,17 +135,19 @@ impl GLWEPacker {
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxSubInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
+ VecZnxBigSubSmallNegateInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxNormalizeTmpBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
{
|
||||
assert!(
|
||||
self.counter < self.accumulators[0].data.n(),
|
||||
(self.counter as u32) < self.accumulators[0].data.n(),
|
||||
"Packing limit of {} reached",
|
||||
self.accumulators[0].data.n() >> self.log_batch
|
||||
self.accumulators[0].data.n().0 as usize >> self.log_batch
|
||||
);
|
||||
|
||||
pack_core(
|
||||
@@ -166,7 +166,7 @@ impl GLWEPacker {
|
||||
where
|
||||
Module<B>: VecZnxCopy,
|
||||
{
|
||||
assert!(self.counter == self.accumulators[0].data.n());
|
||||
assert!(self.counter as u32 == self.accumulators[0].data.n());
|
||||
// Copy result GLWE into res GLWE
|
||||
res.copy(
|
||||
module,
|
||||
@@ -177,18 +177,13 @@ impl GLWEPacker {
|
||||
}
|
||||
}
|
||||
|
||||
fn pack_core_scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
ct_k: usize,
|
||||
k_ksk: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize
|
||||
fn pack_core_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
|
||||
where
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
OUT: GLWEInfos,
|
||||
KEY: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
|
||||
combine_scratch_space(module, out_infos, key_infos)
|
||||
}
|
||||
|
||||
fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
@@ -215,11 +210,13 @@ fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxSubInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
+ VecZnxBigSubSmallNegateInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxNormalizeTmpBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
{
|
||||
let log_n: usize = module.log_n();
|
||||
@@ -271,20 +268,15 @@ fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
}
|
||||
}
|
||||
|
||||
fn combine_scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
ct_k: usize,
|
||||
k_ksk: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize
|
||||
fn combine_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
|
||||
where
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
OUT: GLWEInfos,
|
||||
KEY: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
GLWECiphertext::bytes_of(module.n(), basek, ct_k, rank)
|
||||
GLWECiphertext::alloc_bytes(out_infos)
|
||||
+ (GLWECiphertext::rsh_scratch_space(module.n())
|
||||
| GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank))
|
||||
| GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos))
|
||||
}
|
||||
|
||||
/// [combine] merges two ciphertexts together.
|
||||
@@ -312,19 +304,17 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxSubInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
+ VecZnxBigSubSmallNegateInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxNormalizeTmpBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeGLWECt,
|
||||
{
|
||||
let n: usize = acc.data.n();
|
||||
let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _;
|
||||
let log_n: usize = acc.data.n().log2();
|
||||
let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
|
||||
let basek: usize = a.basek();
|
||||
let k: usize = a.k();
|
||||
let rank: usize = a.rank();
|
||||
|
||||
let gal_el: i64 = if i == 0 {
|
||||
-1
|
||||
@@ -346,7 +336,7 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
// since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
|
||||
if acc.value {
|
||||
if let Some(b) = b {
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
|
||||
|
||||
// a = a * X^-t
|
||||
a.rotate_inplace(module, -t, scratch_1);
|
||||
@@ -365,7 +355,7 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
if let Some(key) = auto_keys.get(&gal_el) {
|
||||
tmp_b.automorphism_inplace(module, key, scratch_1);
|
||||
} else {
|
||||
panic!("auto_key[{}] not found", gal_el);
|
||||
panic!("auto_key[{gal_el}] not found");
|
||||
}
|
||||
|
||||
// a = a * X^-t + b - phi(a * X^-t - b)
|
||||
@@ -382,19 +372,19 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
if let Some(key) = auto_keys.get(&gal_el) {
|
||||
a.automorphism_add_inplace(module, key, scratch);
|
||||
} else {
|
||||
panic!("auto_key[{}] not found", gal_el);
|
||||
panic!("auto_key[{gal_el}] not found");
|
||||
}
|
||||
}
|
||||
} else if let Some(b) = b {
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
|
||||
tmp_b.rotate(module, 1 << (log_n - i - 1), b);
|
||||
tmp_b.rsh(module, 1, scratch_1);
|
||||
|
||||
// a = (b* X^t - phi(b* X^t))
|
||||
if let Some(key) = auto_keys.get(&gal_el) {
|
||||
a.automorphism_sub_ba(module, &tmp_b, key, scratch_1);
|
||||
a.automorphism_sub_negate(module, &tmp_b, key, scratch_1);
|
||||
} else {
|
||||
panic!("auto_key[{}] not found", gal_el);
|
||||
panic!("auto_key[{gal_el}] not found");
|
||||
}
|
||||
|
||||
acc.value = true;
|
||||
|
||||
Reference in New Issue
Block a user