updated Sampling implementation

This commit is contained in:
Jean-Philippe Bossuat
2025-02-14 19:21:14 +01:00
parent 9ff197dd37
commit e7a6ba17ee
3 changed files with 57 additions and 25 deletions

View File

@@ -32,7 +32,7 @@ fn main() {
// a <- Z_{2^prec}[X]/(X^{N}+1)
let mut a: VecZnx = module.new_vec_znx(limbs);
a.fill_uniform(log_base2k, limbs, &mut source);
module.fill_uniform(log_base2k, &mut a, limbs, &mut source);
// Scratch space for DFT values
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs());
@@ -62,7 +62,14 @@ fn main() {
// b <- normalize(buf_big) + e
let mut b: VecZnx = module.new_vec_znx(limbs);
module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry);
b.add_normal(log_base2k, log_base2k * limbs, &mut source, 3.2, 19.0);
module.add_normal(
log_base2k,
&mut b,
log_base2k * limbs,
&mut source,
3.2,
19.0,
);
//Decrypt

View File

@@ -1,25 +1,27 @@
use crate::{Infos, VecZnx, VecZnxApi};
use crate::{Infos, Module, VecZnxApi};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
pub trait Sampling<T: VecZnxApi + Infos> {
/// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source);
fn fill_uniform(&self, log_base2k: usize, a: &mut T, limbs: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<T: Distribution<f64>>(
&mut self,
fn add_dist_f64<D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut T,
log_k: usize,
source: &mut Source,
dist: T,
dist: D,
bound: f64,
);
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(
&mut self,
&self,
log_base2k: usize,
a: &mut T,
log_k: usize,
source: &mut Source,
sigma: f64,
@@ -27,25 +29,24 @@ pub trait Sampling {
);
}
impl Sampling for VecZnx {
fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source) {
impl<T: VecZnxApi + Infos> Sampling<T> for Module {
fn fill_uniform(&self, log_base2k: usize, a: &mut T, limbs: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
let size: usize = self.n() * limbs;
self.data[..size]
let size: usize = a.n() * limbs;
a.raw_mut()[..size]
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
}
fn add_dist_f64<T: Distribution<f64>>(
&mut self,
fn add_dist_f64<D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut T,
log_k: usize,
source: &mut Source,
dist: T,
dist: D,
bound: f64,
) {
assert!(
@@ -57,7 +58,7 @@ impl Sampling for VecZnx {
let log_base2k_rem: usize = log_k % log_base2k;
if log_base2k_rem != 0 {
self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| {
a.at_mut(a.limbs() - 1).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -65,7 +66,7 @@ impl Sampling for VecZnx {
*a += (dist_f64.round() as i64) << log_base2k_rem
});
} else {
self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| {
a.at_mut(a.limbs() - 1).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -76,8 +77,9 @@ impl Sampling for VecZnx {
}
fn add_normal(
&mut self,
&self,
log_base2k: usize,
a: &mut T,
log_k: usize,
source: &mut Source,
sigma: f64,
@@ -85,6 +87,7 @@ impl Sampling for VecZnx {
) {
self.add_dist_f64(
log_base2k,
a,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),

View File

@@ -11,6 +11,8 @@ pub trait VecZnxApi {
/// Returns the minimum size of the [u8] array required to assign a
/// new backend array to a [VecZnx] through [VecZnx::from_bytes].
fn bytes_of(n: usize, limbs: usize) -> usize;
fn raw(&self) -> &[i64];
fn raw_mut(&mut self) -> &mut [i64];
fn as_ptr(&self) -> *const i64;
fn as_mut_ptr(&mut self) -> *mut i64;
fn at(&self, i: usize) -> &[i64];
@@ -69,12 +71,22 @@ impl VecZnxApi for VecZnxBorrow {
self.data
}
fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.data, self.n * self.limbs) }
}
fn raw_mut(&mut self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.data, self.n * self.limbs) }
}
fn at(&self, i: usize) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.data.wrapping_add(self.n * i), self.n) }
let n: usize = self.n();
&self.raw()[n * i..n * (i + 1)]
}
fn at_mut(&mut self, i: usize) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n) }
let n: usize = self.n();
&mut self.raw_mut()[n * i..n * (i + 1)]
}
fn at_ptr(&self, i: usize) -> *const i64 {
@@ -147,6 +159,14 @@ impl VecZnxApi for VecZnx {
bytes_of_vec_znx(n, limbs)
}
fn raw(&self) -> &[i64] {
&self.data
}
fn raw_mut(&mut self) -> &mut [i64] {
&mut self.data
}
/// Returns a non-mutable pointer to the backing array of the [VecZnx].
fn as_ptr(&self) -> *const i64 {
self.data.as_ptr()
@@ -159,12 +179,14 @@ impl VecZnxApi for VecZnx {
/// Returns a non-mutable reference to the i-th limb of the [VecZnx].
fn at(&self, i: usize) -> &[i64] {
&self.data[i * self.n..(i + 1) * self.n]
let n: usize = self.n();
&self.raw()[n * i..n * (i + 1)]
}
/// Returns a mutable reference to the i-th limb of the [VecZnx].
fn at_mut(&mut self, i: usize) -> &mut [i64] {
&mut self.data[i * self.n..(i + 1) * self.n]
let n: usize = self.n();
&mut self.raw_mut()[n * i..n * (i + 1)]
}
/// Returns a non-mutable pointer to the i-th limb of the [VecZnx].