mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add BDD Arithmetic (#98)
* Added some circuit, evaluation + some layouts * Refactor + memory reduction * Rows -> Dnum, Digits -> Dsize * fix #96 + glwe_packing (indirectly CBT) * clippy
This commit is contained in:
committed by
GitHub
parent
37e13b965c
commit
6357a05509
@@ -42,13 +42,13 @@ where
|
||||
|
||||
if block_size > 1 {
|
||||
let cols: usize = (brk_infos.rank() + 1).into();
|
||||
let rows: usize = brk_infos.rows().into();
|
||||
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
|
||||
let dnum: usize = brk_infos.dnum().into();
|
||||
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, dnum) * extension_factor;
|
||||
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
|
||||
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
|
||||
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
|
||||
let acc_dft_add: usize = vmp_res;
|
||||
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
let acc: usize = if extension_factor > 1 {
|
||||
VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor
|
||||
} else {
|
||||
@@ -158,11 +158,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
let n_glwe: usize = brk.n_glwe().into();
|
||||
let extension_factor: usize = lut.extension_factor();
|
||||
let base2k: usize = res.base2k().into();
|
||||
let rows: usize = brk.rows().into();
|
||||
let dnum: usize = brk.dnum().into();
|
||||
let cols: usize = (res.rank() + 1).into();
|
||||
|
||||
let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size());
|
||||
let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows);
|
||||
let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, dnum);
|
||||
let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
|
||||
let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size());
|
||||
@@ -328,7 +328,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
|
||||
let two_n: usize = n_glwe << 1;
|
||||
let base2k: usize = brk.base2k().into();
|
||||
let rows: usize = brk.rows().into();
|
||||
let dnum: usize = brk.dnum().into();
|
||||
|
||||
let cols: usize = (out_mut.rank() + 1).into();
|
||||
|
||||
@@ -351,7 +351,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
|
||||
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
|
||||
|
||||
let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, rows);
|
||||
let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, dnum);
|
||||
let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size());
|
||||
let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size());
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::{fmt, marker::PhantomData};
|
||||
use poulpy_core::{
|
||||
Distribution,
|
||||
layouts::{
|
||||
Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, Rows, TorusPrecision,
|
||||
Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision,
|
||||
prepared::GLWESecretPrepared,
|
||||
},
|
||||
};
|
||||
@@ -23,7 +23,7 @@ pub struct BlindRotationKeyLayout {
|
||||
pub n_lwe: Degree,
|
||||
pub base2k: Base2K,
|
||||
pub k: TorusPrecision,
|
||||
pub rows: Rows,
|
||||
pub dnum: Dnum,
|
||||
pub rank: Rank,
|
||||
}
|
||||
|
||||
@@ -38,12 +38,12 @@ impl BlindRotationKeyInfos for BlindRotationKeyLayout {
|
||||
}
|
||||
|
||||
impl GGSWInfos for BlindRotationKeyLayout {
|
||||
fn digits(&self) -> Digits {
|
||||
Digits(1)
|
||||
fn dsize(&self) -> Dsize {
|
||||
Dsize(1)
|
||||
}
|
||||
|
||||
fn rows(&self) -> Rows {
|
||||
self.rows
|
||||
fn dnum(&self) -> Dnum {
|
||||
self.dnum
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,11 +221,11 @@ impl<D: DataRef, BRT: BlindRotationAlgo> GLWEInfos for BlindRotationKey<D, BRT>
|
||||
}
|
||||
}
|
||||
impl<D: DataRef, BRT: BlindRotationAlgo> GGSWInfos for BlindRotationKey<D, BRT> {
|
||||
fn digits(&self) -> poulpy_core::layouts::Digits {
|
||||
Digits(1)
|
||||
fn dsize(&self) -> poulpy_core::layouts::Dsize {
|
||||
Dsize(1)
|
||||
}
|
||||
|
||||
fn rows(&self) -> Rows {
|
||||
self.keys[0].rows()
|
||||
fn dnum(&self) -> Dnum {
|
||||
self.keys[0].dnum()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::{fmt, marker::PhantomData};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use poulpy_core::{
|
||||
Distribution,
|
||||
layouts::{Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed},
|
||||
layouts::{Base2K, Degree, Dsize, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed},
|
||||
};
|
||||
|
||||
use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyInfos};
|
||||
@@ -128,12 +128,12 @@ impl<D: DataRef, BRA: BlindRotationAlgo> GLWEInfos for BlindRotationKeyCompresse
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRA: BlindRotationAlgo> GGSWInfos for BlindRotationKeyCompressed<D, BRA> {
|
||||
fn rows(&self) -> poulpy_core::layouts::Rows {
|
||||
self.keys[0].rows()
|
||||
fn dnum(&self) -> poulpy_core::layouts::Dnum {
|
||||
self.keys[0].dnum()
|
||||
}
|
||||
|
||||
fn digits(&self) -> poulpy_core::layouts::Digits {
|
||||
Digits(1)
|
||||
fn dsize(&self) -> poulpy_core::layouts::Dsize {
|
||||
Dsize(1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::marker::PhantomData;
|
||||
use poulpy_core::{
|
||||
Distribution,
|
||||
layouts::{
|
||||
Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
|
||||
Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision,
|
||||
prepared::{GGSWCiphertextPrepared, Prepare, PrepareAlloc},
|
||||
},
|
||||
};
|
||||
@@ -63,12 +63,12 @@ impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GLWEInfos for BlindRotationKey
|
||||
}
|
||||
}
|
||||
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GGSWInfos for BlindRotationKeyPrepared<D, BRT, B> {
|
||||
fn digits(&self) -> poulpy_core::layouts::Digits {
|
||||
Digits(1)
|
||||
fn dsize(&self) -> poulpy_core::layouts::Dsize {
|
||||
Dsize(1)
|
||||
}
|
||||
|
||||
fn rows(&self) -> Rows {
|
||||
self.data[0].rows()
|
||||
fn dnum(&self) -> Dnum {
|
||||
self.data[0].dnum()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace,
|
||||
VecZnxRotateInplaceTmpBytes, VecZnxSwitchRing,
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
|
||||
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, Module, ScratchOwned, VecZnx, ZnxInfos, ZnxViewMut},
|
||||
layouts::{Backend, Module, Scratch, ScratchOwned, VecZnx, ZnxInfos, ZnxViewMut},
|
||||
reference::{vec_znx::vec_znx_rotate_inplace, znx::ZnxRef},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -76,12 +77,13 @@ impl LookUpTable {
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplaceTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
Scratch<B>: TakeSlice,
|
||||
{
|
||||
assert!(f.len() <= module.n());
|
||||
|
||||
let base2k: usize = self.base2k;
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes() | (self.domain_size() << 3));
|
||||
|
||||
// Get the number minimum limb to store the message modulus
|
||||
let limbs: usize = k.div_ceil(base2k);
|
||||
@@ -128,19 +130,21 @@ impl LookUpTable {
|
||||
// Rotates half the step to the left
|
||||
|
||||
if self.extension_factor() > 1 {
|
||||
(0..self.extension_factor()).for_each(|i| {
|
||||
let (tmp, _) = scratch.borrow().take_slice(lut_full.n());
|
||||
|
||||
for i in 0..self.extension_factor() {
|
||||
module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0);
|
||||
if i < self.extension_factor() {
|
||||
module.vec_znx_rotate_inplace(-1, &mut lut_full, 0, scratch.borrow());
|
||||
vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
|
||||
}
|
||||
|
||||
self.data.iter_mut().for_each(|a| {
|
||||
for a in self.data.iter_mut() {
|
||||
module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow());
|
||||
});
|
||||
}
|
||||
|
||||
self.rotate(module, -(drift as i64));
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use poulpy_hal::{
|
||||
},
|
||||
layouts::{Backend, Module, ScratchOwned, ZnxView},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl,
|
||||
},
|
||||
source::Source,
|
||||
@@ -82,7 +82,8 @@ where
|
||||
+ TakeVecZnxDftSliceImpl<B>
|
||||
+ ScratchAvailableImpl<B>
|
||||
+ TakeVecZnxImpl<B>
|
||||
+ TakeVecZnxSliceImpl<B>,
|
||||
+ TakeVecZnxSliceImpl<B>
|
||||
+ TakeSliceImpl<B>,
|
||||
{
|
||||
let n_glwe: usize = module.n();
|
||||
let base2k: usize = 19;
|
||||
@@ -106,7 +107,7 @@ where
|
||||
n_lwe: n_lwe.into(),
|
||||
base2k: base2k.into(),
|
||||
k: k_brk.into(),
|
||||
rows: rows_brk.into(),
|
||||
dnum: rows_brk.into(),
|
||||
rank: rank.into(),
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use poulpy_hal::{
|
||||
VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, Module},
|
||||
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
|
||||
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl},
|
||||
};
|
||||
|
||||
use crate::tfhe::blind_rotation::{DivRound, LookUpTable};
|
||||
@@ -19,7 +19,7 @@ where
|
||||
+ VecZnxSwitchRing
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplaceTmpBytes,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
|
||||
{
|
||||
let base2k: usize = 20;
|
||||
let k_lut: usize = 40;
|
||||
@@ -59,7 +59,7 @@ where
|
||||
+ VecZnxSwitchRing
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplaceTmpBytes,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
|
||||
{
|
||||
let base2k: usize = 20;
|
||||
let k_lut: usize = 40;
|
||||
|
||||
@@ -11,7 +11,7 @@ fn test_cggi_blind_rotation_key_serialization() {
|
||||
n_lwe: 64_usize.into(),
|
||||
base2k: 12_usize.into(),
|
||||
k: 54_usize.into(),
|
||||
rows: 2_usize.into(),
|
||||
dnum: 2_usize.into(),
|
||||
rank: 2_usize.into(),
|
||||
};
|
||||
|
||||
@@ -26,7 +26,7 @@ fn test_cggi_blind_rotation_key_compressed_serialization() {
|
||||
n_lwe: 64_usize.into(),
|
||||
base2k: 12_usize.into(),
|
||||
k: 54_usize.into(),
|
||||
rows: 2_usize.into(),
|
||||
dnum: 2_usize.into(),
|
||||
rank: 2_usize.into(),
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user