Add BDD Arithmetic (#98)

* Added some circuit, evaluation + some layouts

* Refactor + memory reduction

* Rows -> Dnum, Digits -> Dsize

* fix #96 + glwe_packing (indirectly CBT)

* clippy
This commit is contained in:
Jean-Philippe Bossuat
2025-10-08 17:52:03 +02:00
committed by GitHub
parent 37e13b965c
commit 6357a05509
119 changed files with 15996 additions and 1659 deletions

View File

@@ -10,7 +10,7 @@ use poulpy_hal::{
use crate::{
layouts::{
GGLWECiphertext, GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos,
GGLWECiphertext, GGLWEInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos,
prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
},
operations::GLWEOperations,
@@ -20,14 +20,14 @@ impl GGSWCiphertext<Vec<u8>> {
pub(crate) fn expand_row_scratch_space<B: Backend, OUT, TSK>(module: &Module<B>, out_infos: &OUT, tsk_infos: &TSK) -> usize
where
OUT: GGSWInfos,
TSK: GGLWELayoutInfos,
TSK: GGLWEInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
{
let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize;
let size_in: usize = out_infos
.k()
.div_ceil(tsk_infos.base2k())
.div_ceil(tsk_infos.digits().into()) as usize;
.div_ceil(tsk_infos.dsize().into()) as usize;
let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size);
let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in);
@@ -56,8 +56,8 @@ impl GGSWCiphertext<Vec<u8>> {
where
OUT: GGSWInfos,
IN: GGSWInfos,
KEY: GGLWELayoutInfos,
TSK: GGLWELayoutInfos,
KEY: GGLWEInfos,
TSK: GGLWEInfos,
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigAllocBytes
@@ -101,8 +101,8 @@ impl GGSWCiphertext<Vec<u8>> {
) -> usize
where
OUT: GGSWInfos,
KEY: GGLWELayoutInfos,
TSK: GGLWELayoutInfos,
KEY: GGLWEInfos,
TSK: GGLWEInfos,
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigAllocBytes
@@ -143,12 +143,12 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
use crate::layouts::{GLWEInfos, LWEInfos};
assert_eq!(self.rank(), a.rank_out());
assert_eq!(self.rows(), a.rows());
assert_eq!(self.dnum(), a.dnum());
assert_eq!(self.n(), module.n() as u32);
assert_eq!(a.n(), module.n() as u32);
assert_eq!(tsk.n(), module.n() as u32);
}
(0..self.rows().into()).for_each(|row_i| {
(0..self.dnum().into()).for_each(|row_i| {
self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
});
self.expand_row(module, tsk, scratch);
@@ -180,7 +180,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
{
(0..lhs.rows().into()).for_each(|row_i| {
(0..lhs.dnum().into()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
@@ -214,7 +214,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
{
(0..self.rows().into()).for_each(|row_i| {
(0..self.dnum().into()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
@@ -255,7 +255,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk);
// Keyswitch the j-th row of the col 0
for row_i in 0..self.rows().into() {
for row_i in 0..self.dnum().into() {
let a = &self.at(row_i, 0).data;
// Pre-compute DFT of (a0, a1, a2)
@@ -277,7 +277,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// Example for rank 3:
//
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
// actually composed of that many rows and we focus on a specific row here
// actually composed of that many dnum and we focus on a specific row here
// implicitely given ci_dft.
//
// # Input
@@ -294,10 +294,10 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i])
let digits: usize = tsk.digits().into();
let dsize: usize = tsk.dsize().into();
let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(dsize));
{
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -315,19 +315,19 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j])
// Extracts a[i] and multipies with Enc(s[i]s[j])
for di in 0..digits {
tmp_a.set_size((ci_dft.size() + di) / digits);
for di in 0..dsize {
tmp_a.set_size((ci_dft.size() + di) / dsize);
// Small optimization for digits > 2
// Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize);
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
module.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
if di == 0 && col_i == 1 {
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
} else {