diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 00568dd..74f9f86 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -111,10 +111,10 @@ impl + AsRef<[u8]>> VecZnx { } } - pub fn rotate(&mut self, k: i64){ - unsafe{ - (0..self.cols()).for_each(|i|{ - (0..self.size()).for_each(|j|{ + pub fn rotate(&mut self, k: i64) { + unsafe { + (0..self.cols()).for_each(|i| { + (0..self.size()).for_each(|j| { znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); }); }) diff --git a/core/src/elem.rs b/core/src/elem.rs index 9a1de39..6e15616 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -34,7 +34,7 @@ pub trait Infos { /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, self.k().div_ceil(self.basek())); + debug_assert!(size >= self.k().div_ceil(self.basek())); size } diff --git a/core/src/glwe/ops.rs b/core/src/glwe/ops.rs index 9248e85..46f6bdc 100644 --- a/core/src/glwe/ops.rs +++ b/core/src/glwe/ops.rs @@ -2,7 +2,7 @@ use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero}; use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, Infos, SetMetaData}; -pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { +pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { fn add(&mut self, module: &Module, a: &A, b: &B) where A: GLWECiphertextToRef, @@ -14,7 +14,6 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { assert_eq!(b.n(), module.n()); assert_eq!(self.n(), module.n()); assert_eq!(a.basek(), b.basek()); - assert_eq!(self.basek(), a.basek()); assert!(self.rank() >= a.rank().max(b.rank())); } @@ -47,8 +46,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { }); }); - // self.set_basek(a.basek()); - // self.set_k(a.k().max(b.k())); + self.set_basek(a.basek()); + self.set_k(set_k(self, a, b)); } fn add_inplace(&mut self, module: &Module, a: &A) @@ -70,7 +69,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); }); - self.set_k(a.k().max(self.k())); + if a.rank() != 0 { + self.set_k(a.k().min(self.k())); + } } fn sub(&mut self, module: &Module, a: &A, b: &B) @@ -118,7 +119,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { }); self.set_basek(a.basek()); - self.set_k(a.k().max(b.k())); + self.set_k(set_k(self, a, b)); } fn sub_inplace_ab(&mut self, module: &Module, a: &A) @@ -140,7 +141,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i); }); - self.set_k(a.k().max(self.k())); + if a.rank() != 0 { + self.set_k(a.k().min(self.k())); + } } fn sub_inplace_ba(&mut self, module: &Module, a: &A) @@ -162,7 +165,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i); }); - self.set_k(a.k().max(self.k())); + if a.rank() != 0 { + self.set_k(a.k().min(self.k())); + } } fn rotate(&mut self, module: &Module, k: i64, a: &A) @@ -184,7 +189,9 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { }); self.set_basek(a.basek()); - self.set_k(a.k()); + if a.rank() != 0 { + self.set_k(a.k().min(self.k())); + } } fn rotate_inplace(&mut self, module: &Module, k: i64) { @@ -209,6 +216,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { assert_eq!(self.n(), module.n()); assert_eq!(a.n(), module.n()); assert_eq!(self.rank(), a.rank()); + assert_eq!(self.k(), a.k()); + assert_eq!(self.basek(), a.basek()); } let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); @@ -246,7 +255,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData { module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch); }); self.set_basek(a.basek()); - self.set_k(a.k()); + self.set_k(a.k().min(self.k())); } fn normalize_inplace(&mut self, module: &Module, scratch: &mut Scratch) { @@ -266,3 +275,19 @@ impl GLWECiphertext> { VecZnx::rsh_scratch_space(module.n()) } } + +// c = op(a, b) +fn set_k(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize { + if a.rank() != 0 || b.rank() != 0 { + let k = if a.rank() == 0 { + b.k() + } else if b.rank() == 0 { + a.k() + } else { + a.k().min(b.k()) + }; + k.min(c.k()) + } else { + c.k() + } +} diff --git a/core/src/glwe/packing.rs b/core/src/glwe/packing.rs index 3496994..bfd7f36 100644 --- a/core/src/glwe/packing.rs +++ b/core/src/glwe/packing.rs @@ -1,7 +1,7 @@ use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; use std::collections::HashMap; -use backend::{FFT64, Module, Scratch}; +use backend::{FFT64, Module, Scratch, VecZnxOps}; /// [StreamPacker] enables only the fly GLWE packing /// with constant memory of Log(N) ciphertexts. @@ -65,7 +65,7 @@ impl GLWEPacker { } /// Implicit reset of the internal state (to be called before a new packing procedure). - pub fn reset(&mut self) { + fn reset(&mut self) { for i in 0..self.accumulators.len() { self.accumulators[i].value = false; self.accumulators[i].control = false; @@ -82,9 +82,7 @@ impl GLWEPacker { GLWECiphertext::trace_galois_elements(module) } - /// Adds a GLWE ciphertext to the [StreamPacker]. And propagates - /// intermediate results among the [Accumulator]s. - /// + /// Adds a GLWE ciphertext to the [StreamPacker]. /// #Arguments /// /// * `module`: static backend FFT tables. @@ -96,11 +94,16 @@ impl GLWEPacker { pub fn add, DataAK: AsRef<[u8]>>( &mut self, module: &Module, - res: &mut Vec>>, a: Option<&GLWECiphertext>, auto_keys: &HashMap>, scratch: &mut Scratch, ) { + assert!( + self.counter < module.n(), + "Packing limit of {} reached", + module.n() >> self.log_batch + ); + pack_core( module, a, @@ -110,35 +113,18 @@ impl GLWEPacker { scratch, ); self.counter += 1 << self.log_batch; - if self.counter == module.n() { - res.push( - self.accumulators[module.log_n() - self.log_batch - 1] - .data - .clone(), - ); - self.reset(); - } } - /// Flushes all accumlators and appends the result to `res`. - pub fn flush>( - &mut self, - module: &Module, - res: &mut Vec>>, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) { - if self.counter != 0 { - while self.counter != 0 { - self.add( - module, - res, - None::<&GLWECiphertext>>, - auto_keys, - scratch, - ); - } - } + /// Flush result to`res`. + pub fn flush + AsRef<[u8]>>(&mut self, module: &Module, res: &mut GLWECiphertext) { + assert!(self.counter == module.n()); + // Copy result GLWE into res GLWE + res.copy( + module, + &self.accumulators[module.log_n() - self.log_batch - 1].data, + ); + + self.reset(); } } diff --git a/core/src/glwe/test_fft64/packing.rs b/core/src/glwe/test_fft64/packing.rs index 0e9ee71..f747697 100644 --- a/core/src/glwe/test_fft64/packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -74,8 +74,6 @@ fn apply() { scratch.borrow(), ); - let mut res: Vec>> = Vec::new(); - (0..module.n() >> log_batch).for_each(|i| { ct.encrypt_sk( &module, @@ -90,11 +88,10 @@ fn apply() { pt.rotate_inplace(&module, -(1 << log_batch)); // X^-batch * pt if reverse_bits_msb(i, log_n as u32) % 5 == 0 { - packer.add(&module, &mut res, Some(&ct), &auto_keys, scratch.borrow()); + packer.add(&module, Some(&ct), &auto_keys, scratch.borrow()); } else { packer.add( &module, - &mut res, None::<&GLWECiphertext>>, &auto_keys, scratch.borrow(), @@ -102,36 +99,29 @@ fn apply() { } }); - packer.flush(&module, &mut res, &auto_keys, scratch.borrow()); - packer.reset(); + let mut res = GLWECiphertext::alloc(&module, basek, k_ct, rank); + packer.flush(&module, &mut res); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - res.iter().enumerate().for_each(|(i, res_i)| { - let mut data: Vec = vec![0i64; module.n()]; - data.iter_mut().enumerate().for_each(|(i, x)| { - if i % 5 == 0 { - *x = reverse_bits_msb(i, log_n as u32) as i64; - } - }); - pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32); - - res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - if i & 1 == 0 { - pt.sub_inplace_ab(&module, &pt_want); - } else { - pt.add_inplace(&module, &pt_want); + let mut data: Vec = vec![0i64; module.n()]; + data.iter_mut().enumerate().for_each(|(i, x)| { + if i % 5 == 0 { + *x = reverse_bits_msb(i, log_n as u32) as i64; } - - let noise_have = pt.data.std(0, basek).log2(); - // println!("noise_have: {}", noise_have); - assert!( - noise_have < -((k_ct - basek) as f64), - "noise: {}", - noise_have - ); }); + pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32); + + res.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + pt.sub_inplace_ab(&module, &pt_want); + + let noise_have = pt.data.std(0, basek).log2(); + // println!("noise_have: {}", noise_have); + assert!( + noise_have < -((k_ct - basek) as f64), + "noise: {}", + noise_have + ); } #[inline(always)]