Add schemes (#71)

* Move br + cbt to schemes/tfhe

* Refactor blind rotation

* Refactor circuit bootstrapping

* Rename exec -> prepared
Jean-Philippe Bossuat
2025-08-15 15:06:26 +02:00
committed by GitHub
parent 8d9897b88b
commit c7219c35e9
130 changed files with 2631 additions and 3270 deletions

View File

@@ -3,8 +3,7 @@ use backend::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64,
VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
},
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
},
@@ -73,7 +72,7 @@ fn main() {
let mut want: Vec<i64> = vec![0; n];
want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
m.encode_vec_i64(basek, 0, log_scale, &want, 4);
module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
// m - BIG(ct[1] * s)
@@ -132,8 +131,7 @@ fn main() {
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
res.decode_vec_i64(basek, 0, ct_size * basek, &mut have);
let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
izip!(want.iter(), have.iter())
.enumerate()
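
The two changed call sites above capture the user-facing part of this refactor: encoding and decoding no longer go through Module-level traits but are inherent methods on VecZnx. A minimal before/after sketch of the call-site change (identifiers as they appear in this example file; the surrounding setup of `m`, `res`, `want`, `have`, `basek`, `log_scale` and `ct_size` is assumed):

// Before (Module-level trait API, removed in this commit):
//     module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
//     module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
// After (inherent VecZnx methods, added in this commit):
//     m.encode_vec_i64(basek, 0, log_scale, &want, 4);
//     res.decode_vec_i64(basek, 0, ct_size * basek, &mut have);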

View File

@@ -1,5 +1,4 @@
use rand_distr::Distribution;
use rug::Float;
use sampling::source::Source;
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
@@ -198,13 +197,6 @@ pub trait VecZnxCopy {
A: VecZnxToRef;
}
pub trait VecZnxStd {
/// Returns the standard deviation of the i-th polynomial.
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
where
A: VecZnxToRef;
}
pub trait VecZnxFillUniform {
/// Fills the first `k.div_ceil(basek)` limbs with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
@@ -269,75 +261,3 @@ pub trait VecZnxAddNormal {
) where
R: VecZnxToMut;
}
pub trait VecZnxEncodeVeci64 {
/// Encodes a vector of i64 on the receiver.
///
/// # Arguments
///
/// * `res_col`: the index of the poly where to encode the data.
/// * `basek`: base two negative logarithm decomposition of the receiver.
/// * `k`: base two negative logarithm of the scaling of the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
where
R: VecZnxToMut;
}
pub trait VecZnxEncodeCoeffsi64 {
/// Encodes a single i64 on the receiver at the given index.
///
/// # Arguments
///
/// * `res_col`: the index of the poly where to encode the data.
/// * `basek`: base two negative logarithm decomposition of the receiver.
/// * `k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
where
R: VecZnxToMut;
}
pub trait VecZnxDecodeVeci64 {
/// Decodes a vector of i64 from the receiver.
///
/// # Arguments
///
/// * `res_col`: the index of the poly from which to decode the data.
/// * `basek`: base two negative logarithm decomposition of the receiver.
/// * `k`: base two logarithm of the scaling of the data.
/// * `data`: data to decode from the receiver.
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
where
R: VecZnxToRef;
}
pub trait VecZnxDecodeCoeffsi64 {
/// Decodes a single i64 from the receiver at the given index.
///
/// # Arguments
///
/// * `res_col`: the index of the poly from which to decode the data.
/// * `basek`: base two negative logarithm decomposition of the receiver.
/// * `k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient to decode (the decoded value is returned).
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
where
R: VecZnxToRef;
}
pub trait VecZnxDecodeVecFloat {
/// Decodes a vector of Float from the receiver.
///
/// # Arguments
/// * `col_i`: the index of the poly from which to decode the data.
/// * `basek`: base two negative logarithm decomposition of the receiver.
/// * `data`: data to decode from the receiver.
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [Float])
where
R: VecZnxToRef;
}
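For readers following the parameter conventions these removed docs describe, here is a small worked sketch with the values used by the tests further down (illustrative only, not part of the commit):

// Illustrative only: how basek and k determine the limb layout of an encoding.
fn limb_layout_sketch() {
    let basek: usize = 17;                  // bits per limb
    let k: usize = 80;                      // precision, e.g. size * basek - 5 in the tests
    let size: usize = k.div_ceil(basek);    // limbs touched: ceil(80 / 17) = 5
    let k_rem: usize = basek - (k % basek); // free low bits of the last limb: 17 - 12 = 5
    assert_eq!((size, k_rem), (5, 5));
    // For small data (log_max + k_rem < 63) the values land in limb size - 1,
    // shifted left by k_rem; larger data is decomposed in base 2^{basek}.
}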

View File

@@ -3,22 +3,19 @@ use sampling::source::Source;
use crate::hal::{
api::{
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, VecZnxDecodeVeci64,
VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
};
@@ -325,18 +322,6 @@ where
}
}
impl<B> VecZnxStd for Module<B>
where
B: Backend + VecZnxStdImpl<B>,
{
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
where
A: VecZnxToRef,
{
B::vec_znx_std_impl(self, basek, a, a_col)
}
}
impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
@@ -428,63 +413,3 @@ where
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxEncodeVeci64 for Module<B>
where
B: Backend + VecZnxEncodeVeci64Impl<B>,
{
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
where
R: VecZnxToMut,
{
B::encode_vec_i64_impl(self, basek, res, res_col, k, data, log_max);
}
}
impl<B> VecZnxEncodeCoeffsi64 for Module<B>
where
B: Backend + VecZnxEncodeCoeffsi64Impl<B>,
{
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
where
R: VecZnxToMut,
{
B::encode_coeff_i64_impl(self, basek, res, res_col, k, i, data, log_max);
}
}
impl<B> VecZnxDecodeVeci64 for Module<B>
where
B: Backend + VecZnxDecodeVeci64Impl<B>,
{
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
where
R: VecZnxToRef,
{
B::decode_vec_i64_impl(self, basek, res, res_col, k, data);
}
}
impl<B> VecZnxDecodeCoeffsi64 for Module<B>
where
B: Backend + VecZnxDecodeCoeffsi64Impl<B>,
{
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
where
R: VecZnxToRef,
{
B::decode_coeff_i64_impl(self, basek, res, res_col, k, i)
}
}
impl<B> VecZnxDecodeVecFloat for Module<B>
where
B: Backend + VecZnxDecodeVecFloatImpl<B>,
{
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [rug::Float])
where
R: VecZnxToRef,
{
B::decode_vec_float_impl(self, basek, res, col_i, data);
}
}

View File

@@ -0,0 +1,204 @@
use itertools::izip;
use rug::{Assign, Float};
use crate::hal::{
api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef},
};
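// Limb convention used throughout this file: limb index 0 holds the most significant
// base-2^{basek} digit of a coefficient, limb index size-1 the least significant one.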
impl<D: DataMut> VecZnx<D> {
pub fn encode_vec_i64(&mut self, basek: usize, col: usize, k: usize, data: &[i64], log_max: usize) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let mut a: VecZnx<&mut [u8]> = self.to_mut();
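// k_rem = size*basek - k when k is not a multiple of basek (the low-order bits of the
// last limb left unused); it equals basek when k is a multiple of basek, in which case
// no final shift is applied.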
let k_rem: usize = basek - (k % basek);
// Zeroes the coefficients of column `col`
(0..a.size()).for_each(|i| {
a.zero_at(col, i);
});
// If 2^{log_max} * 2^{k_rem} < 2^{63}-1 (the shifted values fit in an i64),
// or k_rem == basek (no shift needed), we can simply copy the values
// onto the last limb.
// Otherwise we decompose the values in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where k is not a multiple of basek: shift the written limbs left by k_rem.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|i| {
a.at_mut(col, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
}
}
pub fn encode_coeff_i64(&mut self, basek: usize, col: usize, k: usize, idx: usize, data: i64, log_max: usize) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(idx < a.n());
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
}
let k_rem: usize = basek - (k % basek);
let mut a: VecZnx<&mut [u8]> = self.to_mut();
(0..a.size()).for_each(|j| a.at_mut(col, j)[idx] = 0);
// If 2^{log_max} * 2^{k_rem} < 2^{63}-1 (the shifted value fits in an i64),
// or k_rem == basek (no shift needed), we can simply copy the value
// onto the last limb.
// Otherwise we decompose the value in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[idx] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(col, j_rev)[idx] = (data >> (j * basek)) & mask;
})
}
// Case where k is not a multiple of basek: shift the written limbs left by k_rem.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(col, j)[idx] <<= k_rem;
})
}
}
}
impl<D: DataRef> VecZnx<D> {
pub fn decode_vec_i64(&self, basek: usize, col: usize, k: usize, data: &mut [i64]) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col < a.cols());
}
let a: VecZnx<&[u8]> = self.to_ref();
data.copy_from_slice(a.at(col, 0));
let rem: usize = basek - (k % basek);
if k < basek {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
});
}
})
}
}
pub fn decode_coeff_i64(&self, basek: usize, col: usize, k: usize, idx: usize) -> i64 {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(idx < a.n());
assert!(col < a.cols())
}
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(basek);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
(0..size).for_each(|j| {
let x: i64 = a.at(col, j)[idx];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
pub fn decode_vec_float(&self, basek: usize, col: usize, data: &mut [Float]) {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col < a.cols());
}
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1 << basek) as f64);
// y[i] = sum_j x[j][i] * 2^{-basek*(j+1)}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
}
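
The coefficient-level pair `encode_coeff_i64` / `decode_coeff_i64` is not exercised by the tests shown further down, so here is a minimal round-trip sketch of the new inherent API (allocation signature taken from those tests; the parameter values are illustrative, not part of the commit):

// Sketch only: encode a single coefficient into column 0 and read it back.
use crate::hal::layouts::VecZnx;

fn encode_coeff_round_trip_sketch() {
    let basek: usize = 17;
    let size: usize = 5;
    let k: usize = size * basek;                // k a multiple of basek, so no extra shift
    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(32, 1, size);
    a.encode_coeff_i64(basek, 0, k, 3, -42, 6); // column 0, coefficient index 3, log_max = 6
    assert_eq!(a.decode_coeff_i64(basek, 0, k, 3), -42);
}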

View File

@@ -1,8 +1,10 @@
mod encoding;
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod serialization;
mod stats;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;

View File

@@ -0,0 +1,32 @@
use rug::{
Float,
float::Round,
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use crate::hal::{
api::ZnxInfos,
layouts::{DataRef, VecZnx},
};
impl<D: DataRef> VecZnx<D> {
pub fn std(&self, basek: usize, col: usize) -> f64 {
let prec: u32 = (self.size() * basek) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(basek, col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}
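
This inherent `std` replaces the backend-level `vec_znx_std`; the uniform-fill test further down checks it against 1/sqrt(12), the standard deviation of a normalized uniform variable. A usage sketch under the same assumptions (module construction and fill call as used in the tests; the tolerance is illustrative, not part of the commit):

// Sketch only: std() of a uniformly filled column is close to sqrt(1/12) ~ 0.2887.
use crate::{
    hal::{
        api::{ModuleNew, VecZnxFillUniform, ZnxInfos},
        layouts::{Module, VecZnx},
    },
    implementation::cpu_spqlios::FFT64,
};
use sampling::source::Source;

fn std_of_uniform_sketch() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    let (basek, size): (usize, usize) = (17, 5);
    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
    let mut source: Source = Source::new([0u8; 32]);
    module.vec_znx_fill_uniform(basek, &mut a, 0, size * basek, &mut source);
    assert!((a.std(basek, 0) - 1.0 / 12f64.sqrt()).abs() < 0.01);
}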

View File

@@ -1,5 +1,4 @@
use rand_distr::Distribution;
use rug::Float;
use sampling::source::Source;
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
@@ -288,15 +287,6 @@ pub unsafe trait VecZnxCopyImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxStd] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxStdImpl<B: Backend> {
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
where
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
@@ -373,68 +363,3 @@ pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See \[TODO\] for reference code.
/// * See [crate::hal::api::VecZnxEncodeVeci64] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxEncodeVeci64Impl<B: Backend> {
fn encode_vec_i64_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
data: &[i64],
log_max: usize,
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See \[TODO\] for reference code.
/// * See [crate::hal::api::VecZnxEncodeCoeffsi64] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxEncodeCoeffsi64Impl<B: Backend> {
fn encode_coeff_i64_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
i: usize,
data: i64,
log_max: usize,
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See \[TODO\] for reference code.
/// * See [crate::hal::api::VecZnxDecodeVeci64] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDecodeVeci64Impl<B: Backend> {
fn decode_vec_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
where
R: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See \[TODO\] for reference code.
/// * See [crate::hal::api::VecZnxDecodeCoeffsi64] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDecodeCoeffsi64Impl<B: Backend> {
fn decode_coeff_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
where
R: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See \[TODO\] for reference code.
/// * See [crate::hal::api::VecZnxDecodeVecFloat] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDecodeVecFloatImpl<B: Backend> {
fn decode_vec_float_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
where
R: VecZnxToRef;
}

View File

@@ -0,0 +1,52 @@
use sampling::source::Source;
use crate::hal::{
api::{ZnxInfos, ZnxViewMut},
layouts::VecZnx,
};
pub fn test_vec_znx_encode_vec_i64_lo_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(basek, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
});
}
pub fn test_vec_znx_encode_vec_i64_hi_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
} else {
*x = source.next_i64();
}
});
a.encode_vec_i64(basek, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
})
}
}

View File

@@ -1,14 +1,13 @@
use itertools::izip;
use sampling::source::Source;
use crate::hal::{
api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut},
api::{VecZnxAddNormal, VecZnxFillUniform, ZnxView},
layouts::{Backend, Module, VecZnx},
};
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxFillUniform + VecZnxStd,
Module<B>: VecZnxFillUniform,
{
let n: usize = module.n();
let basek: usize = 17;
@@ -26,7 +25,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = module.vec_znx_std(basek, &a, col_i);
let std: f64 = a.std(basek, col_i);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
@@ -40,7 +39,7 @@ where
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxAddNormal + VecZnxStd,
Module<B>: VecZnxAddNormal,
{
let n: usize = module.n();
let basek: usize = 17;
@@ -61,61 +60,9 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = module.vec_znx_std(basek, &a, col_i) * k_f64;
let std: f64 = a.std(basek, col_i) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
}
})
});
}
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
});
}
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
} else {
*x = source.next_i64();
}
});
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})
}
}

View File

@@ -1,2 +1,5 @@
mod generics;
pub use generics::*;
#[cfg(test)]
mod encoding;

View File

@@ -2,10 +2,7 @@ use crate::{
hal::{
api::ModuleNew,
layouts::Module,
tests::vec_znx::{
test_vec_znx_add_normal, test_vec_znx_encode_vec_i64_hi_norm, test_vec_znx_encode_vec_i64_lo_norm,
test_vec_znx_fill_uniform,
},
tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
},
implementation::cpu_spqlios::FFT64,
};
@@ -21,15 +18,3 @@ fn test_vec_znx_add_normal_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
test_vec_znx_add_normal(&module);
}
#[test]
fn test_vec_znx_encode_vec_lo_norm_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
test_vec_znx_encode_vec_i64_lo_norm(&module);
}
#[test]
fn test_vec_znx_encode_vec_hi_norm_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
test_vec_znx_encode_vec_i64_hi_norm(&module);
}

View File

@@ -1,29 +1,22 @@
use itertools::izip;
use rand_distr::Normal;
use rug::{
Assign, Float,
float::Round,
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use sampling::source::Source;
use crate::{
hal::{
api::{
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxDecodeVecFloat, VecZnxFillDistF64,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
},
implementation::cpu_spqlios::{
@@ -857,35 +850,6 @@ where
})
}
unsafe impl<B: Backend> VecZnxStdImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
where
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let prec: u32 = (a.size() * basek) as u32;
let mut data: Vec<Float> = (0..a.n()).map(|_| Float::with_val(prec, 0)).collect();
module.decode_vec_float(basek, &a, a_col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}
unsafe impl<B: Backend> VecZnxFillUniformImpl<B> for B
where
B: CPUAVX,
@@ -1053,251 +1017,3 @@ where
);
}
}
unsafe impl<B: Backend> VecZnxEncodeVeci64Impl<B> for B
where
B: CPUAVX,
{
fn encode_vec_i64_impl<R>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
data: &[i64],
log_max: usize,
) where
R: VecZnxToMut,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = res.to_mut();
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(res_col < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let mut a: VecZnx<&mut [u8]> = res.to_mut();
let k_rem: usize = basek - (k % basek);
// Zeroes coefficients of the i-th column
(0..a.size()).for_each(|i| {
a.zero_at(res_col, i);
});
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(res_col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(res_col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where self.prec % self.k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|i| {
a.at_mut(res_col, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
}
}
}
unsafe impl<B: Backend> VecZnxEncodeCoeffsi64Impl<B> for B
where
B: CPUAVX,
{
fn encode_coeff_i64_impl<R>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
i: usize,
data: i64,
log_max: usize,
) where
R: VecZnxToMut,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = res.to_mut();
assert!(i < a.n());
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(res_col < a.cols());
}
let k_rem: usize = basek - (k % basek);
let mut a: VecZnx<&mut [u8]> = res.to_mut();
(0..a.size()).for_each(|j| a.at_mut(res_col, j)[i] = 0);
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(res_col, size - 1)[i] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(res_col, j_rev)[i] = (data >> (j * basek)) & mask;
})
}
// Case where prec % k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(res_col, j)[i] <<= k_rem;
})
}
}
}
unsafe impl<B: Backend> VecZnxDecodeVeci64Impl<B> for B
where
B: CPUAVX,
{
fn decode_vec_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
where
R: VecZnxToRef,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(res_col < a.cols());
}
let a: VecZnx<&[u8]> = res.to_ref();
data.copy_from_slice(a.at(res_col, 0));
let rem: usize = basek - (k % basek);
if k < basek {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
});
}
})
}
}
}
unsafe impl<B: Backend> VecZnxDecodeCoeffsi64Impl<B> for B
where
B: CPUAVX,
{
fn decode_coeff_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
where
R: VecZnxToRef,
{
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(i < a.n());
assert!(res_col < a.cols())
}
let a: VecZnx<&[u8]> = res.to_ref();
let size: usize = k.div_ceil(basek);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
(0..size).for_each(|j| {
let x: i64 = a.at(res_col, j)[i];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
}
unsafe impl<B: Backend> VecZnxDecodeVecFloatImpl<B> for B
where
B: CPUAVX,
{
fn decode_vec_float_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
where
R: VecZnxToRef,
{
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(res_col < a.cols());
}
let a: VecZnx<&[u8]> = res.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1 << basek) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
}