mirror of https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add schemes (#71)
* Move br + cbt to schemes/tfhe
* refactor blind rotation
* refactor circuit bootstrapping
* renamed exec -> prepared
committed by GitHub
parent 8d9897b88b
commit c7219c35e9
@@ -1,29 +1,22 @@
use itertools::izip;
use rand_distr::Normal;
use rug::{
    Assign, Float,
    float::Round,
    ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use sampling::source::Source;

use crate::{
    hal::{
        api::{
            TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxDecodeVecFloat, VecZnxFillDistF64,
            VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
            ZnxViewMut, ZnxZero,
            TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
            VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
        },
        layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
        oep::{
            VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
            VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
            VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
            VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
            VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
            VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
            VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
            VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
            VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
            VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
            VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
            VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
            VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
            VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
        },
    },
    implementation::cpu_spqlios::{
@@ -857,35 +850,6 @@ where
    })
}

unsafe impl<B: Backend> VecZnxStdImpl<B> for B
where
    B: CPUAVX,
{
    fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
    where
        A: VecZnxToRef,
    {
        let a: VecZnx<&[u8]> = a.to_ref();
        let prec: u32 = (a.size() * basek) as u32;
        let mut data: Vec<Float> = (0..a.n()).map(|_| Float::with_val(prec, 0)).collect();
        module.decode_vec_float(basek, &a, a_col, &mut data);
        // std = sqrt(sum((xi - avg)^2) / n)
        let mut avg: Float = Float::with_val(prec, 0);
        data.iter().for_each(|x| {
            avg.add_assign_round(x, Round::Nearest);
        });
        avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
        data.iter_mut().for_each(|x| {
            x.sub_assign_round(&avg, Round::Nearest);
        });
        let mut std: Float = Float::with_val(prec, 0);
        data.iter().for_each(|x| std += x * x);
        std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
        std = std.sqrt();
        std.to_f64()
    }
}

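Note (not part of the diff): the removed vec_znx_std_impl computes the population standard deviation of the decoded coefficients, std = sqrt(sum((xi - avg)^2) / n), at prec bits of precision via rug::Float. A minimal standalone sketch of the same formula over plain f64, just to illustrate the arithmetic:

// Illustrative only: population standard deviation over f64, mirroring
// `std = sqrt(sum((xi - avg)^2) / n)` from the removed implementation.
fn population_std(data: &[f64]) -> f64 {
    let n = data.len() as f64;
    let avg = data.iter().sum::<f64>() / n;
    let var = data.iter().map(|x| (x - avg).powi(2)).sum::<f64>() / n;
    var.sqrt()
}

fn main() {
    // [1, 2, 3, 4]: avg = 2.5, var = 1.25, std = sqrt(1.25)
    assert!((population_std(&[1.0, 2.0, 3.0, 4.0]) - 1.25f64.sqrt()).abs() < 1e-12);
}
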
unsafe impl<B: Backend> VecZnxFillUniformImpl<B> for B
where
    B: CPUAVX,
@@ -1053,251 +1017,3 @@ where
        );
    }
}

unsafe impl<B: Backend> VecZnxEncodeVeci64Impl<B> for B
where
    B: CPUAVX,
{
    fn encode_vec_i64_impl<R>(
        _module: &Module<B>,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        data: &[i64],
        log_max: usize,
    ) where
        R: VecZnxToMut,
    {
        let size: usize = k.div_ceil(basek);

        #[cfg(debug_assertions)]
        {
            let a: VecZnx<&mut [u8]> = res.to_mut();
            assert!(
                size <= a.size(),
                "invalid argument k: k.div_ceil(basek)={} > a.size()={}",
                size,
                a.size()
            );
            assert!(res_col < a.cols());
            assert!(data.len() <= a.n())
        }

        let data_len: usize = data.len();
        let mut a: VecZnx<&mut [u8]> = res.to_mut();
        let k_rem: usize = basek - (k % basek);

        // Zeroes coefficients of the i-th column
        (0..a.size()).for_each(|i| {
            a.zero_at(res_col, i);
        });

        // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
        // values on the last limb.
        // Else we decompose values base2k.
        if log_max + k_rem < 63 || k_rem == basek {
            a.at_mut(res_col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
        } else {
            let mask: i64 = (1 << basek) - 1;
            let steps: usize = size.min(log_max.div_ceil(basek));
            (size - steps..size)
                .rev()
                .enumerate()
                .for_each(|(i, i_rev)| {
                    let shift: usize = i * basek;
                    izip!(a.at_mut(res_col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
                })
        }

        // Case where self.prec % self.k != 0.
        if k_rem != basek {
            let steps: usize = size.min(log_max.div_ceil(basek));
            (size - steps..size).rev().for_each(|i| {
                a.at_mut(res_col, i)[..data_len]
                    .iter_mut()
                    .for_each(|x| *x <<= k_rem);
            })
        }
    }
}

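Note (not part of the diff): the else branch above splits each input value into limbs of basek bits with (x >> (i * basek)) & mask, storing the least-significant limb in the last slot. A minimal standalone sketch of that decomposition, assuming a non-negative value that fits in size * basek bits:

// Illustrative only: split a non-negative i64 into `size` limbs of `basek`
// bits, least-significant limb last (as in the removed encode_vec_i64_impl).
fn decompose_base2k(x: i64, basek: usize, size: usize) -> Vec<i64> {
    let mask: i64 = (1 << basek) - 1;
    let mut limbs = vec![0i64; size];
    for i in 0..size {
        // The limb holding bits [i*basek, (i+1)*basek) goes at index size-1-i.
        limbs[size - 1 - i] = (x >> (i * basek)) & mask;
    }
    limbs
}

fn main() {
    // 0b1010_1100 with basek = 4 splits into [0b1010, 0b1100].
    assert_eq!(decompose_base2k(0b1010_1100, 4, 2), vec![0b1010, 0b1100]);
}
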
unsafe impl<B: Backend> VecZnxEncodeCoeffsi64Impl<B> for B
where
    B: CPUAVX,
{
    fn encode_coeff_i64_impl<R>(
        _module: &Module<B>,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        i: usize,
        data: i64,
        log_max: usize,
    ) where
        R: VecZnxToMut,
    {
        let size: usize = k.div_ceil(basek);

        #[cfg(debug_assertions)]
        {
            let a: VecZnx<&mut [u8]> = res.to_mut();
            assert!(i < a.n());
            assert!(
                size <= a.size(),
                "invalid argument k: k.div_ceil(basek)={} > a.size()={}",
                size,
                a.size()
            );
            assert!(res_col < a.cols());
        }

        let k_rem: usize = basek - (k % basek);
        let mut a: VecZnx<&mut [u8]> = res.to_mut();
        (0..a.size()).for_each(|j| a.at_mut(res_col, j)[i] = 0);

        // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
        // values on the last limb.
        // Else we decompose values base2k.
        if log_max + k_rem < 63 || k_rem == basek {
            a.at_mut(res_col, size - 1)[i] = data;
        } else {
            let mask: i64 = (1 << basek) - 1;
            let steps: usize = size.min(log_max.div_ceil(basek));
            (size - steps..size)
                .rev()
                .enumerate()
                .for_each(|(j, j_rev)| {
                    a.at_mut(res_col, j_rev)[i] = (data >> (j * basek)) & mask;
                })
        }

        // Case where prec % k != 0.
        if k_rem != basek {
            let steps: usize = size.min(log_max.div_ceil(basek));
            (size - steps..size).rev().for_each(|j| {
                a.at_mut(res_col, j)[i] <<= k_rem;
            })
        }
    }
}

unsafe impl<B: Backend> VecZnxDecodeVeci64Impl<B> for B
where
    B: CPUAVX,
{
    fn decode_vec_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
    where
        R: VecZnxToRef,
    {
        let size: usize = k.div_ceil(basek);
        #[cfg(debug_assertions)]
        {
            let a: VecZnx<&[u8]> = res.to_ref();
            assert!(
                data.len() >= a.n(),
                "invalid data: data.len()={} < a.n()={}",
                data.len(),
                a.n()
            );
            assert!(res_col < a.cols());
        }

        let a: VecZnx<&[u8]> = res.to_ref();
        data.copy_from_slice(a.at(res_col, 0));
        let rem: usize = basek - (k % basek);
        if k < basek {
            data.iter_mut().for_each(|x| *x >>= rem);
        } else {
            (1..size).for_each(|i| {
                if i == size - 1 && rem != basek {
                    let k_rem: usize = basek - rem;
                    izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
                        *y = (*y << k_rem) + (x >> rem);
                    });
                } else {
                    izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
                        *y = (*y << basek) + x;
                    });
                }
            })
        }
    }
}

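Note (not part of the diff): the loop above recomposes each value Horner-style, folding limbs most-significant first with res = (res << basek) + limb, with the last limb handled separately when k is not a multiple of basek. A minimal sketch of the common case where k is a multiple of basek:

// Illustrative only: Horner-style recomposition of an i64 from base-2^k
// limbs (most-significant limb first), as in the decode loop above for the
// case where k is a multiple of basek.
fn recompose_base2k(limbs: &[i64], basek: usize) -> i64 {
    limbs.iter().fold(0i64, |acc, &x| (acc << basek) + x)
}

fn main() {
    // basek = 4: [0b1010, 0b1100] -> 0b1010_1100
    assert_eq!(recompose_base2k(&[0b1010, 0b1100], 4), 0b1010_1100);
}
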
unsafe impl<B: Backend> VecZnxDecodeCoeffsi64Impl<B> for B
where
    B: CPUAVX,
{
    fn decode_coeff_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
    where
        R: VecZnxToRef,
    {
        #[cfg(debug_assertions)]
        {
            let a: VecZnx<&[u8]> = res.to_ref();
            assert!(i < a.n());
            assert!(res_col < a.cols())
        }

        let a: VecZnx<&[u8]> = res.to_ref();
        let size: usize = k.div_ceil(basek);
        let mut res: i64 = 0;
        let rem: usize = basek - (k % basek);
        (0..size).for_each(|j| {
            let x: i64 = a.at(res_col, j)[i];
            if j == size - 1 && rem != basek {
                let k_rem: usize = basek - rem;
                res = (res << k_rem) + (x >> rem);
            } else {
                res = (res << basek) + x;
            }
        });
        res
    }
}

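Note (not part of the diff): a self-contained round-trip sketch of the aligned base-2^k representation that encode_coeff_i64_impl and decode_coeff_i64_impl operate on; it uses the decomposition path for illustration even though a value this small would take the direct-copy branch in the original code:

// Illustrative only: round-trip a single 6-bit value through the base-2^k
// limb representation with basek = 4, k = 6 (so k % basek != 0 and the
// stored limbs carry a left-shift by k_rem = 2).
fn main() {
    let (basek, k) = (4usize, 6usize);
    let size = k.div_ceil(basek);        // 2 limbs
    let k_rem = basek - (k % basek);     // 2 unused low bits per stored limb
    let mask: i64 = (1 << basek) - 1;
    let value: i64 = 0b10_1101;          // 45, log_max = 6

    // Encode: decompose into basek-bit limbs (least-significant limb last),
    // then align every limb by shifting left by k_rem.
    let mut limbs = vec![0i64; size];
    for j in 0..size {
        limbs[size - 1 - j] = (value >> (j * basek)) & mask;
    }
    for limb in limbs.iter_mut() {
        *limb <<= k_rem;
    }

    // Decode: fold limbs most-significant first; the last limb is shifted
    // right by k_rem and the final fold uses basek - k_rem, which cancels
    // the alignment shift carried by the earlier limbs.
    let mut decoded = 0i64;
    for (j, &x) in limbs.iter().enumerate() {
        if j == size - 1 && k_rem != basek {
            decoded = (decoded << (basek - k_rem)) + (x >> k_rem);
        } else {
            decoded = (decoded << basek) + x;
        }
    }
    assert_eq!(decoded, value);
}
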
unsafe impl<B: Backend> VecZnxDecodeVecFloatImpl<B> for B
where
    B: CPUAVX,
{
    fn decode_vec_float_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
    where
        R: VecZnxToRef,
    {
        #[cfg(debug_assertions)]
        {
            let a: VecZnx<&[u8]> = res.to_ref();
            assert!(
                data.len() >= a.n(),
                "invalid data: data.len()={} < a.n()={}",
                data.len(),
                a.n()
            );
            assert!(res_col < a.cols());
        }

        let a: VecZnx<&[u8]> = res.to_ref();
        let size: usize = a.size();
        let prec: u32 = (basek * size) as u32;

        // 2^{basek}
        let base = Float::with_val(prec, (1 << basek) as f64);

        // y[i] = sum x[j][i] * 2^{-basek*j}
        (0..size).for_each(|i| {
            if i == 0 {
                izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
                    y.assign(*x);
                    *y /= &base;
                });
            } else {
                izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
                    *y += Float::with_val(prec, *x);
                    *y /= &base;
                });
            }
        });
    }
}

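Note (not part of the diff): the loop above evaluates y = sum_j x[j] * 2^{-basek*(j+1)} by walking limbs from least to most significant and dividing by 2^{basek} at each step. A minimal sketch of the same recurrence over a single coefficient in plain f64 (53-bit precision) instead of rug::Float:

// Illustrative only: evaluate y = sum_j x[j] * 2^{-basek*(j+1)} by repeated
// division, processing limbs from least to most significant as above.
fn decode_float(limbs: &[i64], basek: usize) -> f64 {
    let base = (1u64 << basek) as f64;
    limbs.iter().rev().fold(0.0f64, |acc, &x| (acc + x as f64) / base)
}

fn main() {
    // basek = 4: limbs [0b1010, 0b1100] represent 10/16 + 12/256 = 0.671875
    assert_eq!(decode_float(&[0b1010, 0b1100], 4), 0.671875);
}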