Add schemes (#71)

* Move br + cbt to schemes/tfhe

* refactor blind rotation

* refactor circuit bootstrapping

* renamed exec -> prepared
This commit is contained in:
Jean-Philippe Bossuat
2025-08-15 15:06:26 +02:00
committed by GitHub
parent 8d9897b88b
commit c7219c35e9
130 changed files with 2631 additions and 3270 deletions

View File

@@ -1,29 +1,22 @@
use itertools::izip;
use rand_distr::Normal;
use rug::{
Assign, Float,
float::Round,
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use sampling::source::Source;
use crate::{
hal::{
api::{
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxDecodeVecFloat, VecZnxFillDistF64,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
},
implementation::cpu_spqlios::{
@@ -857,35 +850,6 @@ where
})
}
unsafe impl<B: Backend> VecZnxStdImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
where
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let prec: u32 = (a.size() * basek) as u32;
let mut data: Vec<Float> = (0..a.n()).map(|_| Float::with_val(prec, 0)).collect();
module.decode_vec_float(basek, &a, a_col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}
unsafe impl<B: Backend> VecZnxFillUniformImpl<B> for B
where
B: CPUAVX,
@@ -1053,251 +1017,3 @@ where
);
}
}
unsafe impl<B: Backend> VecZnxEncodeVeci64Impl<B> for B
where
B: CPUAVX,
{
fn encode_vec_i64_impl<R>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
data: &[i64],
log_max: usize,
) where
R: VecZnxToMut,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = res.to_mut();
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(res_col < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let mut a: VecZnx<&mut [u8]> = res.to_mut();
let k_rem: usize = basek - (k % basek);
// Zeroes coefficients of the i-th column
(0..a.size()).for_each(|i| {
a.zero_at(res_col, i);
});
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(res_col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(res_col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where self.prec % self.k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|i| {
a.at_mut(res_col, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
}
}
}
unsafe impl<B: Backend> VecZnxEncodeCoeffsi64Impl<B> for B
where
B: CPUAVX,
{
fn encode_coeff_i64_impl<R>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
i: usize,
data: i64,
log_max: usize,
) where
R: VecZnxToMut,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = res.to_mut();
assert!(i < a.n());
assert!(
size <= a.size(),
"invalid argument k: k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(res_col < a.cols());
}
let k_rem: usize = basek - (k % basek);
let mut a: VecZnx<&mut [u8]> = res.to_mut();
(0..a.size()).for_each(|j| a.at_mut(res_col, j)[i] = 0);
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(res_col, size - 1)[i] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(res_col, j_rev)[i] = (data >> (j * basek)) & mask;
})
}
// Case where prec % k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(res_col, j)[i] <<= k_rem;
})
}
}
}
unsafe impl<B: Backend> VecZnxDecodeVeci64Impl<B> for B
where
B: CPUAVX,
{
fn decode_vec_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
where
R: VecZnxToRef,
{
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(res_col < a.cols());
}
let a: VecZnx<&[u8]> = res.to_ref();
data.copy_from_slice(a.at(res_col, 0));
let rem: usize = basek - (k % basek);
if k < basek {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
});
}
})
}
}
}
unsafe impl<B: Backend> VecZnxDecodeCoeffsi64Impl<B> for B
where
B: CPUAVX,
{
fn decode_coeff_i64_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
where
R: VecZnxToRef,
{
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(i < a.n());
assert!(res_col < a.cols())
}
let a: VecZnx<&[u8]> = res.to_ref();
let size: usize = k.div_ceil(basek);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
(0..size).for_each(|j| {
let x: i64 = a.at(res_col, j)[i];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
}
unsafe impl<B: Backend> VecZnxDecodeVecFloatImpl<B> for B
where
B: CPUAVX,
{
fn decode_vec_float_impl<R>(_module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
where
R: VecZnxToRef,
{
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = res.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(res_col < a.cols());
}
let a: VecZnx<&[u8]> = res.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1 << basek) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
}