Add BDD Arithmetic (#98)

* Added some circuit, evaluation + some layouts

* Refactor + memory reduction

* Rows -> Dnum, Digits -> Dsize

* fix #96 + glwe_packing (indirectly CBT)

* clippy
This commit is contained in:
Jean-Philippe Bossuat
2025-10-08 17:52:03 +02:00
committed by GitHub
parent 37e13b965c
commit 6357a05509
119 changed files with 15996 additions and 1659 deletions

View File

@@ -7,7 +7,7 @@ use poulpy_hal::{
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};
use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared};
use crate::layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared};
impl GLWECiphertext<Vec<u8>> {
pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
@@ -19,13 +19,13 @@ impl GLWECiphertext<Vec<u8>> {
where
OUT: GLWEInfos,
IN: GLWEInfos,
KEY: GGLWELayoutInfos,
KEY: GGLWEInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
let in_size: usize = in_infos
.k()
.div_ceil(key_apply.base2k())
.div_ceil(key_apply.digits().into()) as usize;
.div_ceil(key_apply.dsize().into()) as usize;
let out_size: usize = out_infos.size();
let ksk_size: usize = key_apply.size();
let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
@@ -41,12 +41,12 @@ impl GLWECiphertext<Vec<u8>> {
let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes();
if in_infos.base2k() == key_apply.base2k() {
res_dft + ((ai_dft + vmp) | normalize_big)
} else if key_apply.digits() == 1 {
} else if key_apply.dsize() == 1 {
// In this case, we only need one column, temporary, that we can drop once a_dft is computed.
let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes();
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
} else {
// Since we stride over a to get a_dft when digits > 1, we need to store the full columns of a with in the base conversion.
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size);
res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
@@ -55,7 +55,7 @@ impl GLWECiphertext<Vec<u8>> {
pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_apply: &KEY) -> usize
where
OUT: GLWEInfos,
KEY: GGLWELayoutInfos,
KEY: GGLWEInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
{
Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply)
@@ -105,7 +105,7 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
lhs.k(),
rhs.base2k(),
rhs.k(),
rhs.digits(),
rhs.dsize(),
rhs.rank_in(),
rhs.rank_out(),
)={scrach_needed}",
@@ -256,7 +256,7 @@ impl<D: DataRef> GLWECiphertext<D> {
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
{
if rhs.digits() == 1 {
if rhs.dsize() == 1 {
return keyswitch_vmp_one_digit(
module,
self.base2k().into(),
@@ -275,7 +275,7 @@ impl<D: DataRef> GLWECiphertext<D> {
res_dft,
&self.data,
&rhs.key.data,
rhs.digits().into(),
rhs.dsize().into(),
scratch,
)
}
@@ -333,7 +333,7 @@ fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
mut res_dft: VecZnxDft<DataRes, B>,
a: &VecZnx<DataIn>,
mat: &VmpPMat<DataVmp, B>,
digits: usize,
dsize: usize,
scratch: &mut Scratch<B>,
) -> VecZnxBig<DataRes, B>
where
@@ -351,24 +351,24 @@ where
{
let cols: usize = a.cols();
let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk);
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(digits));
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(dsize));
ai_dft.data_mut().fill(0);
if basek_in == basek_ksk {
for di in 0..digits {
ai_dft.set_size((a_size + di) / digits);
for di in 0..dsize {
ai_dft.set_size((a_size + di) / dsize);
// Small optimization for digits > 2
// Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols - 1 {
module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, a, j + 1);
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a, j + 1);
}
if di == 0 {
@@ -383,20 +383,20 @@ where
module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2);
}
for di in 0..digits {
ai_dft.set_size((a_size + di) / digits);
for di in 0..dsize {
ai_dft.set_size((a_size + di) / dsize);
// Small optimization for digits > 2
// Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols - 1 {
module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, &a_conv, j);
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j);
}
if di == 0 {