mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
more fixes
This commit is contained in:
@@ -107,7 +107,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
|||||||
// values on the last limb.
|
// values on the last limb.
|
||||||
// Else we decompose values base2k.
|
// Else we decompose values base2k.
|
||||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||||
a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||||
} else {
|
} else {
|
||||||
let mask: i64 = (1 << log_base2k) - 1;
|
let mask: i64 = (1 << log_base2k) - 1;
|
||||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||||
@@ -116,7 +116,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
|||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(i, i_rev)| {
|
.for_each(|(i, i_rev)| {
|
||||||
let shift: usize = i * log_base2k;
|
let shift: usize = i * log_base2k;
|
||||||
izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
|||||||
if log_k_rem != log_base2k {
|
if log_k_rem != log_base2k {
|
||||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||||
(size - steps..size).rev().for_each(|i| {
|
(size - steps..size).rev().for_each(|i| {
|
||||||
a.at_poly_mut(col_i, i)[..data_len]
|
a.at_mut(col_i, i)[..data_len]
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.for_each(|x| *x <<= log_k_rem);
|
.for_each(|x| *x <<= log_k_rem);
|
||||||
})
|
})
|
||||||
@@ -143,16 +143,16 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
|
|||||||
);
|
);
|
||||||
assert!(col_i < a.cols());
|
assert!(col_i < a.cols());
|
||||||
}
|
}
|
||||||
data.copy_from_slice(a.at_poly(col_i, 0));
|
data.copy_from_slice(a.at(col_i, 0));
|
||||||
let rem: usize = log_base2k - (log_k % log_base2k);
|
let rem: usize = log_base2k - (log_k % log_base2k);
|
||||||
(1..size).for_each(|i| {
|
(1..size).for_each(|i| {
|
||||||
if i == size - 1 && rem != log_base2k {
|
if i == size - 1 && rem != log_base2k {
|
||||||
let k_rem: usize = log_base2k - rem;
|
let k_rem: usize = log_base2k - rem;
|
||||||
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||||
*y = (*y << k_rem) + (x >> rem);
|
*y = (*y << k_rem) + (x >> rem);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||||
*y = (*y << log_base2k) + x;
|
*y = (*y << log_base2k) + x;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -180,12 +180,12 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
|
|||||||
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
|
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
|
||||||
(0..size).for_each(|i| {
|
(0..size).for_each(|i| {
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||||
y.assign(*x);
|
y.assign(*x);
|
||||||
*y /= &base;
|
*y /= &base;
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||||
*y += Float::with_val(prec, *x);
|
*y += Float::with_val(prec, *x);
|
||||||
*y /= &base;
|
*y /= &base;
|
||||||
});
|
});
|
||||||
@@ -209,13 +209,13 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
|
|||||||
}
|
}
|
||||||
|
|
||||||
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
||||||
(0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0);
|
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
|
||||||
|
|
||||||
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
|
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
|
||||||
// values on the last limb.
|
// values on the last limb.
|
||||||
// Else we decompose values base2k.
|
// Else we decompose values base2k.
|
||||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||||
a.at_poly_mut(col_i, size - 1)[i] = value;
|
a.at_mut(col_i, size - 1)[i] = value;
|
||||||
} else {
|
} else {
|
||||||
let mask: i64 = (1 << log_base2k) - 1;
|
let mask: i64 = (1 << log_base2k) - 1;
|
||||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||||
@@ -223,7 +223,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
|
|||||||
.rev()
|
.rev()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(j, j_rev)| {
|
.for_each(|(j, j_rev)| {
|
||||||
a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask;
|
a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask;
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
|
|||||||
if log_k_rem != log_base2k {
|
if log_k_rem != log_base2k {
|
||||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||||
(size - steps..size).rev().for_each(|j| {
|
(size - steps..size).rev().for_each(|j| {
|
||||||
a.at_poly_mut(col_i, j)[i] <<= log_k_rem;
|
a.at_mut(col_i, j)[i] <<= log_k_rem;
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
use std::cmp::{max, min};
|
|
||||||
|
|
||||||
use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
|
|
||||||
module: &Module<B>,
|
|
||||||
b: &mut T,
|
|
||||||
a: &T,
|
|
||||||
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
|
|
||||||
) where
|
|
||||||
<T as ZnxLayout>::Scalar: IntegerType,
|
|
||||||
{
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(a.n(), module.n());
|
|
||||||
assert_eq!(b.n(), module.n());
|
|
||||||
}
|
|
||||||
let a_cols: usize = a.cols();
|
|
||||||
let b_cols: usize = b.cols();
|
|
||||||
let min_cols: usize = min(a_cols, b_cols);
|
|
||||||
// Applies over the shared cols between (a, b)
|
|
||||||
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
|
|
||||||
// Zeroes the remaining cols of b.
|
|
||||||
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ffi_ternary_op_factory<T>(
|
|
||||||
module_ptr: *const MODULE,
|
|
||||||
c_size: usize,
|
|
||||||
c_sl: usize,
|
|
||||||
a_size: usize,
|
|
||||||
a_sl: usize,
|
|
||||||
b_size: usize,
|
|
||||||
b_sl: usize,
|
|
||||||
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64),
|
|
||||||
) -> impl Fn(&mut [T], &[T], &[T]) {
|
|
||||||
move |cv: &mut [T], av: &[T], bv: &[T]| unsafe {
|
|
||||||
op_fn(
|
|
||||||
module_ptr,
|
|
||||||
cv.as_mut_ptr(),
|
|
||||||
c_size as u64,
|
|
||||||
c_sl as u64,
|
|
||||||
av.as_ptr(),
|
|
||||||
a_size as u64,
|
|
||||||
a_sl as u64,
|
|
||||||
bv.as_ptr(),
|
|
||||||
b_size as u64,
|
|
||||||
b_sl as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ffi_binary_op_factory_type_0<T>(
|
|
||||||
module_ptr: *const MODULE,
|
|
||||||
b_size: usize,
|
|
||||||
b_sl: usize,
|
|
||||||
a_size: usize,
|
|
||||||
a_sl: usize,
|
|
||||||
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64),
|
|
||||||
) -> impl Fn(&mut [T], &[T]) {
|
|
||||||
move |bv: &mut [T], av: &[T]| unsafe {
|
|
||||||
op_fn(
|
|
||||||
module_ptr,
|
|
||||||
bv.as_mut_ptr(),
|
|
||||||
b_size as u64,
|
|
||||||
b_sl as u64,
|
|
||||||
av.as_ptr(),
|
|
||||||
a_size as u64,
|
|
||||||
a_sl as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ffi_binary_op_factory_type_1<T>(
|
|
||||||
module_ptr: *const MODULE,
|
|
||||||
k: i64,
|
|
||||||
b_size: usize,
|
|
||||||
b_sl: usize,
|
|
||||||
a_size: usize,
|
|
||||||
a_sl: usize,
|
|
||||||
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64),
|
|
||||||
) -> impl Fn(&mut [T], &[T]) {
|
|
||||||
move |bv: &mut [T], av: &[T]| unsafe {
|
|
||||||
op_fn(
|
|
||||||
module_ptr,
|
|
||||||
k,
|
|
||||||
bv.as_mut_ptr(),
|
|
||||||
b_size as u64,
|
|
||||||
b_sl as u64,
|
|
||||||
av.as_ptr(),
|
|
||||||
a_size as u64,
|
|
||||||
a_sl as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@ pub mod encoding;
|
|||||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||||
// Other modules and exports
|
// Other modules and exports
|
||||||
pub mod ffi;
|
pub mod ffi;
|
||||||
mod internals;
|
|
||||||
pub mod mat_znx_dft;
|
pub mod mat_znx_dft;
|
||||||
pub mod module;
|
pub mod module;
|
||||||
pub mod sampling;
|
pub mod sampling;
|
||||||
|
|||||||
@@ -32,12 +32,12 @@ pub trait Sampling {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Sampling for Module<B> {
|
impl<B: Backend> Sampling for Module<B> {
|
||||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) {
|
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) {
|
||||||
let base2k: u64 = 1 << log_base2k;
|
let base2k: u64 = 1 << log_base2k;
|
||||||
let mask: u64 = base2k - 1;
|
let mask: u64 = base2k - 1;
|
||||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||||
(0..size).for_each(|j| {
|
(0..size).for_each(|j| {
|
||||||
a.at_poly_mut(col_i, j)
|
a.at_mut(col_a, j)
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||||
})
|
})
|
||||||
@@ -47,7 +47,7 @@ impl<B: Backend> Sampling for Module<B> {
|
|||||||
&self,
|
&self,
|
||||||
log_base2k: usize,
|
log_base2k: usize,
|
||||||
a: &mut VecZnx,
|
a: &mut VecZnx,
|
||||||
col_i: usize,
|
col_a: usize,
|
||||||
log_k: usize,
|
log_k: usize,
|
||||||
source: &mut Source,
|
source: &mut Source,
|
||||||
dist: D,
|
dist: D,
|
||||||
@@ -63,7 +63,7 @@ impl<B: Backend> Sampling for Module<B> {
|
|||||||
let log_base2k_rem: usize = log_k % log_base2k;
|
let log_base2k_rem: usize = log_k % log_base2k;
|
||||||
|
|
||||||
if log_base2k_rem != 0 {
|
if log_base2k_rem != 0 {
|
||||||
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| {
|
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
|
||||||
let mut dist_f64: f64 = dist.sample(source);
|
let mut dist_f64: f64 = dist.sample(source);
|
||||||
while dist_f64.abs() > bound {
|
while dist_f64.abs() > bound {
|
||||||
dist_f64 = dist.sample(source)
|
dist_f64 = dist.sample(source)
|
||||||
@@ -71,7 +71,7 @@ impl<B: Backend> Sampling for Module<B> {
|
|||||||
*a += (dist_f64.round() as i64) << log_base2k_rem;
|
*a += (dist_f64.round() as i64) << log_base2k_rem;
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| {
|
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
|
||||||
let mut dist_f64: f64 = dist.sample(source);
|
let mut dist_f64: f64 = dist.sample(source);
|
||||||
while dist_f64.abs() > bound {
|
while dist_f64.abs() > bound {
|
||||||
dist_f64 = dist.sample(source)
|
dist_f64 = dist.sample(source)
|
||||||
@@ -85,7 +85,7 @@ impl<B: Backend> Sampling for Module<B> {
|
|||||||
&self,
|
&self,
|
||||||
log_base2k: usize,
|
log_base2k: usize,
|
||||||
a: &mut VecZnx,
|
a: &mut VecZnx,
|
||||||
col_i: usize,
|
col_a: usize,
|
||||||
log_k: usize,
|
log_k: usize,
|
||||||
source: &mut Source,
|
source: &mut Source,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
@@ -94,7 +94,7 @@ impl<B: Backend> Sampling for Module<B> {
|
|||||||
self.add_dist_f64(
|
self.add_dist_f64(
|
||||||
log_base2k,
|
log_base2k,
|
||||||
a,
|
a,
|
||||||
col_i,
|
col_a,
|
||||||
log_k,
|
log_k,
|
||||||
source,
|
source,
|
||||||
Normal::new(0.0, sigma).unwrap(),
|
Normal::new(0.0, sigma).unwrap(),
|
||||||
@@ -125,7 +125,7 @@ mod tests {
|
|||||||
(0..cols).for_each(|col_j| {
|
(0..cols).for_each(|col_j| {
|
||||||
if col_j != col_i {
|
if col_j != col_i {
|
||||||
(0..size).for_each(|limb_i| {
|
(0..size).for_each(|limb_i| {
|
||||||
assert_eq!(a.at_poly(col_j, limb_i), zero);
|
assert_eq!(a.at(col_j, limb_i), zero);
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let std: f64 = a.std(col_i, log_base2k);
|
let std: f64 = a.std(col_i, log_base2k);
|
||||||
@@ -159,7 +159,7 @@ mod tests {
|
|||||||
(0..cols).for_each(|col_j| {
|
(0..cols).for_each(|col_j| {
|
||||||
if col_j != col_i {
|
if col_j != col_i {
|
||||||
(0..size).for_each(|limb_i| {
|
(0..size).for_each(|limb_i| {
|
||||||
assert_eq!(a.at_poly(col_j, limb_i), zero);
|
assert_eq!(a.at(col_j, limb_i), zero);
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let std: f64 = a.std(col_i, log_base2k) * k_f64;
|
let std: f64 = a.std(col_i, log_base2k) * k_f64;
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ mod tests {
|
|||||||
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
|
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
|
||||||
|
|
||||||
// a <- AUTO(a)
|
// a <- AUTO(a)
|
||||||
module.vec_znx_automorphism_inplace(p, &mut a);
|
module.vec_znx_automorphism_inplace(p, &mut a, 0);
|
||||||
|
|
||||||
// b_dft <- DFT(AUTO(a))
|
// b_dft <- DFT(AUTO(a))
|
||||||
module.vec_znx_dft(&mut b_dft, &a);
|
module.vec_znx_dft(&mut b_dft, &a);
|
||||||
|
|||||||
Reference in New Issue
Block a user