more fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-04-29 18:16:09 +02:00
parent 917a472437
commit 06d0c5e832
5 changed files with 22 additions and 119 deletions

View File

@@ -32,12 +32,12 @@ pub trait Sampling {
}
impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..size).for_each(|j| {
a.at_poly_mut(col_i, j)
a.at_mut(col_a, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
@@ -47,7 +47,7 @@ impl<B: Backend> Sampling for Module<B> {
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
col_a: usize,
log_k: usize,
source: &mut Source,
dist: D,
@@ -63,7 +63,7 @@ impl<B: Backend> Sampling for Module<B> {
let log_base2k_rem: usize = log_k % log_base2k;
if log_base2k_rem != 0 {
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| {
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -71,7 +71,7 @@ impl<B: Backend> Sampling for Module<B> {
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| {
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -85,7 +85,7 @@ impl<B: Backend> Sampling for Module<B> {
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
col_a: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
@@ -94,7 +94,7 @@ impl<B: Backend> Sampling for Module<B> {
self.add_dist_f64(
log_base2k,
a,
col_i,
col_a,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
@@ -125,7 +125,7 @@ mod tests {
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero);
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, log_base2k);
@@ -159,7 +159,7 @@ mod tests {
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero);
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, log_base2k) * k_f64;