From e7a6ba17ee1425e31ba61b76a72808b72970dfd5 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 14 Feb 2025 19:21:14 +0100 Subject: [PATCH] updated Sampling implementation --- base2k/examples/rlwe_encrypt.rs | 11 +++++++-- base2k/src/sampling.rs | 41 ++++++++++++++++++--------------- base2k/src/vec_znx.rs | 30 ++++++++++++++++++++---- 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index f2ffbe4..0281ec3 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -32,7 +32,7 @@ fn main() { // a <- Z_{2^prec}[X]/(X^{N}+1) let mut a: VecZnx = module.new_vec_znx(limbs); - a.fill_uniform(log_base2k, limbs, &mut source); + module.fill_uniform(log_base2k, &mut a, limbs, &mut source); // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs()); @@ -62,7 +62,14 @@ fn main() { // b <- normalize(buf_big) + e let mut b: VecZnx = module.new_vec_znx(limbs); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); - b.add_normal(log_base2k, log_base2k * limbs, &mut source, 3.2, 19.0); + module.add_normal( + log_base2k, + &mut b, + log_base2k * limbs, + &mut source, + 3.2, + 19.0, + ); //Decrypt diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index bfe1ec3..1698d0f 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,25 +1,27 @@ -use crate::{Infos, VecZnx, VecZnxApi}; +use crate::{Infos, Module, VecZnxApi}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; -pub trait Sampling { +pub trait Sampling { /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source); + fn fill_uniform(&self, log_base2k: usize, a: &mut T, limbs: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( - &mut self, + fn add_dist_f64>( + &self, log_base2k: usize, + a: &mut T, log_k: usize, source: &mut Source, - dist: T, + dist: D, bound: f64, ); /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. fn add_normal( - &mut self, + &self, log_base2k: usize, + a: &mut T, log_k: usize, source: &mut Source, sigma: f64, @@ -27,25 +29,24 @@ pub trait Sampling { ); } -impl Sampling for VecZnx { - fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source) { +impl Sampling for Module { + fn fill_uniform(&self, log_base2k: usize, a: &mut T, limbs: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - - let size: usize = self.n() * limbs; - - self.data[..size] + let size: usize = a.n() * limbs; + a.raw_mut()[..size] .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); } - fn add_dist_f64>( - &mut self, + fn add_dist_f64>( + &self, log_base2k: usize, + a: &mut T, log_k: usize, source: &mut Source, - dist: T, + dist: D, bound: f64, ) { assert!( @@ -57,7 +58,7 @@ impl Sampling for VecZnx { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { + a.at_mut(a.limbs() - 1).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -65,7 +66,7 @@ impl Sampling for VecZnx { *a += (dist_f64.round() as i64) << log_base2k_rem }); } else { - self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { + a.at_mut(a.limbs() - 1).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -76,8 +77,9 @@ impl Sampling for VecZnx { } fn add_normal( - &mut self, + &self, log_base2k: usize, + a: &mut T, log_k: usize, source: &mut Source, sigma: f64, @@ -85,6 +87,7 @@ impl Sampling for VecZnx { ) { self.add_dist_f64( log_base2k, + a, log_k, source, Normal::new(0.0, sigma).unwrap(), diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 216ccb0..b5680e4 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -11,6 +11,8 @@ pub trait VecZnxApi { /// Returns the minimum size of the [u8] array required to assign a /// new backend array to a [VecZnx] through [VecZnx::from_bytes]. fn bytes_of(n: usize, limbs: usize) -> usize; + fn raw(&self) -> &[i64]; + fn raw_mut(&mut self) -> &mut [i64]; fn as_ptr(&self) -> *const i64; fn as_mut_ptr(&mut self) -> *mut i64; fn at(&self, i: usize) -> &[i64]; @@ -69,12 +71,22 @@ impl VecZnxApi for VecZnxBorrow { self.data } + fn raw(&self) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.data, self.n * self.limbs) } + } + + fn raw_mut(&mut self) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.data, self.n * self.limbs) } + } + fn at(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.data.wrapping_add(self.n * i), self.n) } + let n: usize = self.n(); + &self.raw()[n * i..n * (i + 1)] } fn at_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n) } + let n: usize = self.n(); + &mut self.raw_mut()[n * i..n * (i + 1)] } fn at_ptr(&self, i: usize) -> *const i64 { @@ -147,6 +159,14 @@ impl VecZnxApi for VecZnx { bytes_of_vec_znx(n, limbs) } + fn raw(&self) -> &[i64] { + &self.data + } + + fn raw_mut(&mut self) -> &mut [i64] { + &mut self.data + } + /// Returns a non-mutable pointer to the backing array of the [VecZnx]. fn as_ptr(&self) -> *const i64 { self.data.as_ptr() @@ -159,12 +179,14 @@ impl VecZnxApi for VecZnx { /// Returns a non-mutable reference to the i-th limb of the [VecZnx]. fn at(&self, i: usize) -> &[i64] { - &self.data[i * self.n..(i + 1) * self.n] + let n: usize = self.n(); + &self.raw()[n * i..n * (i + 1)] } /// Returns a mutable reference to the i-th limb of the [VecZnx]. fn at_mut(&mut self, i: usize) -> &mut [i64] { - &mut self.data[i * self.n..(i + 1) * self.n] + let n: usize = self.n(); + &mut self.raw_mut()[n * i..n * (i + 1)] } /// Returns a non-mutable pointer to the i-th limb of the [VecZnx].