Two changes:

- Fix setting `k` in `GlweOps`
- Improve GLWEPacker API avoid accumulating beyond limit (#50)
This commit is contained in:
Janmajaya Mall
2025-07-09 16:23:56 +05:30
parent b99f43aa0f
commit 64edc869d0
5 changed files with 79 additions and 78 deletions

View File

@@ -111,10 +111,10 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
} }
} }
pub fn rotate(&mut self, k: i64){ pub fn rotate(&mut self, k: i64) {
unsafe{ unsafe {
(0..self.cols()).for_each(|i|{ (0..self.cols()).for_each(|i| {
(0..self.size()).for_each(|j|{ (0..self.size()).for_each(|j| {
znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j));
}); });
}) })

View File

@@ -34,7 +34,7 @@ pub trait Infos {
/// Returns the number of size per polynomial. /// Returns the number of size per polynomial.
fn size(&self) -> usize { fn size(&self) -> usize {
let size: usize = self.inner().size(); let size: usize = self.inner().size();
debug_assert_eq!(size, self.k().div_ceil(self.basek())); debug_assert!(size >= self.k().div_ceil(self.basek()));
size size
} }

View File

@@ -2,7 +2,7 @@ use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero};
use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, Infos, SetMetaData}; use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, Infos, SetMetaData};
pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized {
fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B) fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where where
A: GLWECiphertextToRef, A: GLWECiphertextToRef,
@@ -14,7 +14,6 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
assert_eq!(b.n(), module.n()); assert_eq!(b.n(), module.n());
assert_eq!(self.n(), module.n()); assert_eq!(self.n(), module.n());
assert_eq!(a.basek(), b.basek()); assert_eq!(a.basek(), b.basek());
assert_eq!(self.basek(), a.basek());
assert!(self.rank() >= a.rank().max(b.rank())); assert!(self.rank() >= a.rank().max(b.rank()));
} }
@@ -47,8 +46,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
}); });
}); });
// self.set_basek(a.basek()); self.set_basek(a.basek());
// self.set_k(a.k().max(b.k())); self.set_k(set_k(self, a, b));
} }
fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A) fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -70,7 +69,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); if a.rank() != 0 {
self.set_k(a.k().min(self.k()));
}
} }
fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B) fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
@@ -118,7 +119,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
}); });
self.set_basek(a.basek()); self.set_basek(a.basek());
self.set_k(a.k().max(b.k())); self.set_k(set_k(self, a, b));
} }
fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A) fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -140,7 +141,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i); module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); if a.rank() != 0 {
self.set_k(a.k().min(self.k()));
}
} }
fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A) fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -162,7 +165,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i); module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); if a.rank() != 0 {
self.set_k(a.k().min(self.k()));
}
} }
fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A) fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A)
@@ -184,7 +189,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
}); });
self.set_basek(a.basek()); self.set_basek(a.basek());
self.set_k(a.k()); if a.rank() != 0 {
self.set_k(a.k().min(self.k()));
}
} }
fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) { fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) {
@@ -209,6 +216,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
assert_eq!(self.n(), module.n()); assert_eq!(self.n(), module.n());
assert_eq!(a.n(), module.n()); assert_eq!(a.n(), module.n());
assert_eq!(self.rank(), a.rank()); assert_eq!(self.rank(), a.rank());
assert_eq!(self.k(), a.k());
assert_eq!(self.basek(), a.basek());
} }
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
@@ -246,7 +255,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData {
module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch); module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
}); });
self.set_basek(a.basek()); self.set_basek(a.basek());
self.set_k(a.k()); self.set_k(a.k().min(self.k()));
} }
fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) { fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) {
@@ -266,3 +275,19 @@ impl GLWECiphertext<Vec<u8>> {
VecZnx::rsh_scratch_space(module.n()) VecZnx::rsh_scratch_space(module.n())
} }
} }
// c = op(a, b)
fn set_k(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
if a.rank() != 0 || b.rank() != 0 {
let k = if a.rank() == 0 {
b.k()
} else if b.rank() == 0 {
a.k()
} else {
a.k().min(b.k())
};
k.min(c.k())
} else {
c.k()
}
}

View File

@@ -1,7 +1,7 @@
use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore};
use std::collections::HashMap; use std::collections::HashMap;
use backend::{FFT64, Module, Scratch}; use backend::{FFT64, Module, Scratch, VecZnxOps};
/// [StreamPacker] enables only the fly GLWE packing /// [StreamPacker] enables only the fly GLWE packing
/// with constant memory of Log(N) ciphertexts. /// with constant memory of Log(N) ciphertexts.
@@ -65,7 +65,7 @@ impl GLWEPacker {
} }
/// Implicit reset of the internal state (to be called before a new packing procedure). /// Implicit reset of the internal state (to be called before a new packing procedure).
pub fn reset(&mut self) { fn reset(&mut self) {
for i in 0..self.accumulators.len() { for i in 0..self.accumulators.len() {
self.accumulators[i].value = false; self.accumulators[i].value = false;
self.accumulators[i].control = false; self.accumulators[i].control = false;
@@ -82,9 +82,7 @@ impl GLWEPacker {
GLWECiphertext::trace_galois_elements(module) GLWECiphertext::trace_galois_elements(module)
} }
/// Adds a GLWE ciphertext to the [StreamPacker]. And propagates /// Adds a GLWE ciphertext to the [StreamPacker].
/// intermediate results among the [Accumulator]s.
///
/// #Arguments /// #Arguments
/// ///
/// * `module`: static backend FFT tables. /// * `module`: static backend FFT tables.
@@ -96,11 +94,16 @@ impl GLWEPacker {
pub fn add<DataA: AsRef<[u8]>, DataAK: AsRef<[u8]>>( pub fn add<DataA: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
res: &mut Vec<GLWECiphertext<Vec<u8>>>,
a: Option<&GLWECiphertext<DataA>>, a: Option<&GLWECiphertext<DataA>>,
auto_keys: &HashMap<i64, GLWEAutomorphismKey<DataAK, FFT64>>, auto_keys: &HashMap<i64, GLWEAutomorphismKey<DataAK, FFT64>>,
scratch: &mut Scratch, scratch: &mut Scratch,
) { ) {
assert!(
self.counter < module.n(),
"Packing limit of {} reached",
module.n() >> self.log_batch
);
pack_core( pack_core(
module, module,
a, a,
@@ -110,35 +113,18 @@ impl GLWEPacker {
scratch, scratch,
); );
self.counter += 1 << self.log_batch; self.counter += 1 << self.log_batch;
if self.counter == module.n() {
res.push(
self.accumulators[module.log_n() - self.log_batch - 1]
.data
.clone(),
);
self.reset();
}
} }
/// Flushes all accumlators and appends the result to `res`. /// Flush result to`res`.
pub fn flush<DataAK: AsRef<[u8]>>( pub fn flush<Data: AsMut<[u8]> + AsRef<[u8]>>(&mut self, module: &Module<FFT64>, res: &mut GLWECiphertext<Data>) {
&mut self, assert!(self.counter == module.n());
module: &Module<FFT64>, // Copy result GLWE into res GLWE
res: &mut Vec<GLWECiphertext<Vec<u8>>>, res.copy(
auto_keys: &HashMap<i64, GLWEAutomorphismKey<DataAK, FFT64>>, module,
scratch: &mut Scratch, &self.accumulators[module.log_n() - self.log_batch - 1].data,
) { );
if self.counter != 0 {
while self.counter != 0 { self.reset();
self.add(
module,
res,
None::<&GLWECiphertext<Vec<u8>>>,
auto_keys,
scratch,
);
}
}
} }
} }

View File

@@ -74,8 +74,6 @@ fn apply() {
scratch.borrow(), scratch.borrow(),
); );
let mut res: Vec<GLWECiphertext<Vec<u8>>> = Vec::new();
(0..module.n() >> log_batch).for_each(|i| { (0..module.n() >> log_batch).for_each(|i| {
ct.encrypt_sk( ct.encrypt_sk(
&module, &module,
@@ -90,11 +88,10 @@ fn apply() {
pt.rotate_inplace(&module, -(1 << log_batch)); // X^-batch * pt pt.rotate_inplace(&module, -(1 << log_batch)); // X^-batch * pt
if reverse_bits_msb(i, log_n as u32) % 5 == 0 { if reverse_bits_msb(i, log_n as u32) % 5 == 0 {
packer.add(&module, &mut res, Some(&ct), &auto_keys, scratch.borrow()); packer.add(&module, Some(&ct), &auto_keys, scratch.borrow());
} else { } else {
packer.add( packer.add(
&module, &module,
&mut res,
None::<&GLWECiphertext<Vec<u8>>>, None::<&GLWECiphertext<Vec<u8>>>,
&auto_keys, &auto_keys,
scratch.borrow(), scratch.borrow(),
@@ -102,36 +99,29 @@ fn apply() {
} }
}); });
packer.flush(&module, &mut res, &auto_keys, scratch.borrow()); let mut res = GLWECiphertext::alloc(&module, basek, k_ct, rank);
packer.reset(); packer.flush(&module, &mut res);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
let mut data: Vec<i64> = vec![0i64; module.n()];
res.iter().enumerate().for_each(|(i, res_i)| { data.iter_mut().enumerate().for_each(|(i, x)| {
let mut data: Vec<i64> = vec![0i64; module.n()]; if i % 5 == 0 {
data.iter_mut().enumerate().for_each(|(i, x)| { *x = reverse_bits_msb(i, log_n as u32) as i64;
if i % 5 == 0 {
*x = reverse_bits_msb(i, log_n as u32) as i64;
}
});
pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32);
res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
if i & 1 == 0 {
pt.sub_inplace_ab(&module, &pt_want);
} else {
pt.add_inplace(&module, &pt_want);
} }
let noise_have = pt.data.std(0, basek).log2();
// println!("noise_have: {}", noise_have);
assert!(
noise_have < -((k_ct - basek) as f64),
"noise: {}",
noise_have
);
}); });
pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32);
res.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
pt.sub_inplace_ab(&module, &pt_want);
let noise_have = pt.data.std(0, basek).log2();
// println!("noise_have: {}", noise_have);
assert!(
noise_have < -((k_ct - basek) as f64),
"noise: {}",
noise_have
);
} }
#[inline(always)] #[inline(always)]