diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 9ed71a0..5f08a89 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -3,7 +3,7 @@ use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, + ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, }; pub trait MatZnxDftAlloc { diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index cb51e0d..4c45f36 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -72,6 +72,31 @@ impl + AsRef<[u8]>> ScalarZnx { .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } + + pub fn fill_binary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 2] = [0, 1]; + let weights: [f64; 2] = [1.0 - prob, prob]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_binary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (source.next_u32() & 1) as i64); + self.at_mut(col, 0).shuffle(source); + } + + pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { + assert!(self.n() % block_size == 0); + for chunk in self.at_mut(col, 0).chunks_mut(block_size) { + chunk[0] = 1; + chunk.shuffle(source); + } + } } impl>> ScalarZnx { diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 28c1724..cd85f0c 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -763,6 +763,9 @@ impl + AsMut<[u8]>> GLWECiphertext { ), SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), + SecretDistribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), + SecretDistribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), SecretDistribution::ZERO => {} } diff --git a/core/src/glwe_keys.rs b/core/src/glwe_keys.rs index 8f04408..be8aa43 100644 --- a/core/src/glwe_keys.rs +++ b/core/src/glwe_keys.rs @@ -10,8 +10,11 @@ use crate::{GLWECiphertextFourier, Infos}; pub(crate) enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight + BinaryFixed(usize), // Binary with fixed Hamming weight + BinaryProb(f64), // Binary with probabilistic Hamming weight + BinaryBlock(usize), // Binary split in block of size 2^k ZERO, // Debug mod - NONE, + NONE, // Unitialized } pub struct GLWESecret { @@ -65,6 +68,30 @@ impl + AsRef<[u8]>> GLWESecret { self.dist = SecretDistribution::TernaryFixed(hw); } + pub fn fill_binary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_prob(i, prob, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_hw(i, hw, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, module: &Module, block_size: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_block(i, block_size, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryBlock(block_size); + } + pub fn fill_zero(&mut self) { self.data.zero(); self.dist = SecretDistribution::ZERO;