mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add schemes (#71)
* Move br + cbt to schemes/tfhe * refactor blind rotation * refactor circuit bootstrapping * renamed exec -> prepared
This commit is contained in:
committed by
GitHub
parent
8d9897b88b
commit
c7219c35e9
@@ -3,8 +3,7 @@ use backend::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
|
||||
VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64,
|
||||
VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
|
||||
},
|
||||
@@ -73,7 +72,7 @@ fn main() {
|
||||
let mut want: Vec<i64> = vec![0; n];
|
||||
want.iter_mut()
|
||||
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
|
||||
module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
|
||||
m.encode_vec_i64(basek, 0, log_scale, &want, 4);
|
||||
module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
|
||||
|
||||
// m - BIG(ct[1] * s)
|
||||
@@ -132,8 +131,7 @@ fn main() {
|
||||
|
||||
// have = m * 2^{log_scale} + e
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
|
||||
|
||||
res.decode_vec_i64(basek, 0, ct_size * basek, &mut have);
|
||||
let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
|
||||
izip!(want.iter(), have.iter())
|
||||
.enumerate()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
|
||||
@@ -198,13 +197,6 @@ pub trait VecZnxCopy {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxStd {
|
||||
/// Returns the standard devaition of the i-th polynomial.
|
||||
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillUniform {
|
||||
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
@@ -269,75 +261,3 @@ pub trait VecZnxAddNormal {
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxEncodeVeci64 {
|
||||
/// encode a vector of i64 on the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxEncodeCoeffsi64 {
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVeci64 {
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeCoeffsi64 {
|
||||
/// decode a single of i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVecFloat {
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
@@ -3,22 +3,19 @@ use sampling::source::Source;
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
|
||||
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, VecZnxDecodeVeci64,
|
||||
VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
|
||||
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
|
||||
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
|
||||
VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
|
||||
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
|
||||
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
|
||||
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
|
||||
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
|
||||
VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
|
||||
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -325,18 +322,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxStd for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxStdImpl<B>,
|
||||
{
|
||||
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_std_impl(self, basek, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillUniformImpl<B>,
|
||||
@@ -428,63 +413,3 @@ where
|
||||
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxEncodeVeci64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxEncodeVeci64Impl<B>,
|
||||
{
|
||||
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::encode_vec_i64_impl(self, basek, res, res_col, k, data, log_max);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxEncodeCoeffsi64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxEncodeCoeffsi64Impl<B>,
|
||||
{
|
||||
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::encode_coeff_i64_impl(self, basek, res, res_col, k, i, data, log_max);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeVeci64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeVeci64Impl<B>,
|
||||
{
|
||||
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_vec_i64_impl(self, basek, res, res_col, k, data);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeCoeffsi64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeCoeffsi64Impl<B>,
|
||||
{
|
||||
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_coeff_i64_impl(self, basek, res, res_col, k, i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeVecFloat for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeVecFloatImpl<B>,
|
||||
{
|
||||
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [rug::Float])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_vec_float_impl(self, basek, res, col_i, data);
|
||||
}
|
||||
}
|
||||
|
||||
204
backend/src/hal/layouts/encoding.rs
Normal file
204
backend/src/hal/layouts/encoding.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
|
||||
use crate::hal::{
|
||||
api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
};
|
||||
|
||||
impl<D: DataMut> VecZnx<D> {
|
||||
pub fn encode_vec_i64(&mut self, basek: usize, col: usize, k: usize, data: &[i64], log_max: usize) {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
}
|
||||
|
||||
let data_len: usize = data.len();
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| {
|
||||
a.zero_at(col, i);
|
||||
});
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
let shift: usize = i * basek;
|
||||
izip!(a.at_mut(col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
})
|
||||
}
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_mut(col, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= k_rem);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode_coeff_i64(&mut self, basek: usize, col: usize, k: usize, idx: usize, data: i64, log_max: usize) {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(idx < a.n());
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col < a.cols());
|
||||
}
|
||||
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
(0..a.size()).for_each(|j| a.at_mut(col, j)[idx] = 0);
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col, size - 1)[idx] = data;
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_mut(col, j_rev)[idx] = (data >> (j * basek)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(col, j)[idx] <<= k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn decode_vec_i64(&self, basek: usize, col: usize, k: usize, data: &mut [i64]) {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col < a.cols());
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
data.copy_from_slice(a.at(col, 0));
|
||||
let rem: usize = basek - (k % basek);
|
||||
if k < basek {
|
||||
data.iter_mut().for_each(|x| *x >>= rem);
|
||||
} else {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << basek) + x;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_coeff_i64(&self, basek: usize, col: usize, k: usize, idx: usize) -> i64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
assert!(idx < a.n());
|
||||
assert!(col < a.cols())
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let mut res: i64 = 0;
|
||||
let rem: usize = basek - (k % basek);
|
||||
(0..size).for_each(|j| {
|
||||
let x: i64 = a.at(col, j)[idx];
|
||||
if j == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
res = (res << k_rem) + (x >> rem);
|
||||
} else {
|
||||
res = (res << basek) + x;
|
||||
}
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
pub fn decode_vec_float(&self, basek: usize, col: usize, data: &mut [Float]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col < a.cols());
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
let size: usize = a.size();
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base = Float::with_val(prec, (1 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
mod encoding;
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod serialization;
|
||||
mod stats;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
|
||||
32
backend/src/hal/layouts/stats.rs
Normal file
32
backend/src/hal/layouts/stats.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use rug::{
|
||||
Float,
|
||||
float::Round,
|
||||
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
|
||||
};
|
||||
|
||||
use crate::hal::{
|
||||
api::ZnxInfos,
|
||||
layouts::{DataRef, VecZnx},
|
||||
};
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn std(&self, basek: usize, col: usize) -> f64 {
|
||||
let prec: u32 = (self.size() * basek) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(basek, col, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
avg.add_assign_round(x, Round::Nearest);
|
||||
});
|
||||
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
data.iter_mut().for_each(|x| {
|
||||
x.sub_assign_round(&avg, Round::Nearest);
|
||||
});
|
||||
let mut std: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| std += x * x);
|
||||
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
std = std.sqrt();
|
||||
std.to_f64()
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
|
||||
@@ -288,15 +287,6 @@ pub unsafe trait VecZnxCopyImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxStd] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxStdImpl<B: Backend> {
|
||||
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
@@ -373,68 +363,3 @@ pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeVeci64Impl<B: Backend> {
|
||||
fn encode_vec_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeCoeffsi64Impl<B: Backend> {
|
||||
fn encode_coeff_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
data: i64,
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVeci64Impl<B: Backend> {
|
||||
fn decode_vec_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeCoeffsi64Impl<B: Backend> {
|
||||
fn decode_coeff_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVecFloat] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVecFloatImpl<B: Backend> {
|
||||
fn decode_vec_float_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
52
backend/src/hal/tests/vec_znx/encoding.rs
Normal file
52
backend/src/hal/tests/vec_znx/encoding.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{ZnxInfos, ZnxViewMut},
|
||||
layouts::VecZnx,
|
||||
};
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_lo_norm() {
|
||||
let n: usize = 32;
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
a.encode_vec_i64(basek, col_i, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
a.decode_vec_i64(basek, col_i, k, &mut want);
|
||||
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_hi_norm() {
|
||||
let n: usize = 32;
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
a.encode_vec_i64(basek, col_i, k, &have, 63);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
a.decode_vec_i64(basek, col_i, k, &mut want);
|
||||
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
api::{VecZnxAddNormal, VecZnxFillUniform, ZnxView},
|
||||
layouts::{Backend, Module, VecZnx},
|
||||
};
|
||||
|
||||
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxFillUniform + VecZnxStd,
|
||||
Module<B>: VecZnxFillUniform,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 17;
|
||||
@@ -26,7 +25,7 @@ where
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i);
|
||||
let std: f64 = a.std(basek, col_i);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
@@ -40,7 +39,7 @@ where
|
||||
|
||||
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxAddNormal + VecZnxStd,
|
||||
Module<B>: VecZnxAddNormal,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 17;
|
||||
@@ -61,61 +60,9 @@ where
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i) * k_f64;
|
||||
let std: f64 = a.std(basek, col_i) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
mod generics;
|
||||
pub use generics::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod encoding;
|
||||
|
||||
@@ -2,10 +2,7 @@ use crate::{
|
||||
hal::{
|
||||
api::ModuleNew,
|
||||
layouts::Module,
|
||||
tests::vec_znx::{
|
||||
test_vec_znx_add_normal, test_vec_znx_encode_vec_i64_hi_norm, test_vec_znx_encode_vec_i64_lo_norm,
|
||||
test_vec_znx_fill_uniform,
|
||||
},
|
||||
tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
@@ -21,15 +18,3 @@ fn test_vec_znx_add_normal_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_add_normal(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_lo_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_lo_norm(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_hi_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_hi_norm(&module);
|
||||
}
|
||||
|
||||
@@ -1,29 +1,22 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::Normal;
|
||||
use rug::{
|
||||
Assign, Float,
|
||||
float::Round,
|
||||
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxDecodeVecFloat, VecZnxFillDistF64,
|
||||
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
|
||||
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
|
||||
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
|
||||
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
|
||||
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
|
||||
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
@@ -857,35 +850,6 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxStdImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let prec: u32 = (a.size() * basek) as u32;
|
||||
let mut data: Vec<Float> = (0..a.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
module.decode_vec_float(basek, &a, a_col, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
avg.add_assign_round(x, Round::Nearest);
|
||||
});
|
||||
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
data.iter_mut().for_each(|x| {
|
||||
x.sub_assign_round(&avg, Round::Nearest);
|
||||
});
|
||||
let mut std: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| std += x * x);
|
||||
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
std = std.sqrt();
|
||||
std.to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxFillUniformImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
@@ -1053,251 +1017,3 @@ where
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxEncodeVeci64Impl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn encode_vec_i64_impl<R>(
|
||||
_module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let size: usize = k.div_ceil(basek);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(res_col < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
}
|
||||
|
||||
let data_len: usize = data.len();
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| {
|
||||
a.zero_at(res_col, i);
|
||||
});
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(res_col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
let shift: usize = i * basek;
|
||||
izip!(a.at_mut(res_col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
})
|
||||
}
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_mut(res_col, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= k_rem);
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxEncodeCoeffsi64Impl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn encode_coeff_i64_impl<R>(
|
||||
_module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
data: i64,
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let size: usize = k.div_ceil(basek);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(i < a.n());
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(res_col < a.cols());
|
||||
}
|
||||
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
(0..a.size()).for_each(|j| a.at_mut(res_col, j)[i] = 0);
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(res_col, size - 1)[i] = data;
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_mut(res_col, j_rev)[i] = (data >> (j * basek)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(res_col, j)[i] <<= k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxDecodeVeci64Impl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn decode_vec_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
let size: usize = k.div_ceil(basek);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(res_col < a.cols());
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
data.copy_from_slice(a.at(res_col, 0));
|
||||
let rem: usize = basek - (k % basek);
|
||||
if k < basek {
|
||||
data.iter_mut().for_each(|x| *x >>= rem);
|
||||
} else {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << basek) + x;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxDecodeCoeffsi64Impl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn decode_coeff_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
assert!(i < a.n());
|
||||
assert!(res_col < a.cols())
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let mut res: i64 = 0;
|
||||
let rem: usize = basek - (k % basek);
|
||||
(0..size).for_each(|j| {
|
||||
let x: i64 = a.at(res_col, j)[i];
|
||||
if j == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
res = (res << k_rem) + (x >> rem);
|
||||
} else {
|
||||
res = (res << basek) + x;
|
||||
}
|
||||
});
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxDecodeVecFloatImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn decode_vec_float_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(res_col < a.cols());
|
||||
}
|
||||
|
||||
let a: VecZnx<&[u8]> = res.to_ref();
|
||||
let size: usize = a.size();
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base = Float::with_val(prec, (1 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user