Add BDD Arithmetic (#98)
* Added some circuits, evaluation + some layouts
* Refactor + memory reduction
* Rows -> Dnum, Digits -> Dsize
* fix #96 + glwe_packing (indirectly CBT)
* clippy
Committed by GitHub
Parent: 37e13b965c
Commit: 6357a05509
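The rename called out above (Rows -> Dnum, Digits -> Dsize) runs through every file touched below: `rows()` becomes `dnum()` (the number of gadget-decomposition digits) and `digits()` becomes `dsize()` (the size of one digit in limbs). A minimal stand-in sketch of the new accessors; the struct and field names here are illustrative assumptions, not poulpy's actual layout types:

```rust
// Hypothetical stand-in illustrating the rename in this commit.
// Field and method names are assumptions for illustration only.
struct GadgetLayout {
    dnum: usize,  // formerly `rows`
    dsize: usize, // formerly `digits`
}

impl GadgetLayout {
    fn dnum(&self) -> usize {
        self.dnum // old call sites used `self.rows()`
    }
    fn dsize(&self) -> usize {
        self.dsize // old call sites used `self.digits()`
    }
}

fn main() {
    let layout = GadgetLayout { dnum: 4, dsize: 2 };
    println!("dnum = {}, dsize = {}", layout.dnum(), layout.dsize());
}
```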
@@ -7,7 +7,7 @@ use poulpy_hal::{
     layouts::{Backend, DataMut, DataRef, Module, Scratch},
 };
 
-use crate::layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared};
+use crate::layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared};
 
 impl GGLWEAutomorphismKey<Vec<u8>> {
     pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
@@ -17,8 +17,8 @@ impl GGLWEAutomorphismKey<Vec<u8>> {
         ggsw_infos: &GGSW,
     ) -> usize
     where
-        OUT: GGLWELayoutInfos,
-        IN: GGLWELayoutInfos,
+        OUT: GGLWEInfos,
+        IN: GGLWEInfos,
         GGSW: GGSWInfos,
         Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
@@ -31,7 +31,7 @@ impl GGLWEAutomorphismKey<Vec<u8>> {
         ggsw_infos: &GGSW,
     ) -> usize
     where
-        OUT: GGLWELayoutInfos,
+        OUT: GGLWEInfos,
         GGSW: GGSWInfos,
         Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {

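The `*_scratch_space` functions above return a byte bound that callers use to pre-allocate temporary memory before running an external product. The hunks elide the full parameter lists, so the sketch below uses made-up parameters and an illustrative bound; only the query-then-allocate pattern is taken from the diff:

```rust
// Sketch of the scratch-space pattern: a `*_scratch_space` query
// returns a byte count, the caller allocates once, then passes the
// scratch handle into the operation. All names here are illustrative.
struct Scratch {
    buf: Vec<u8>,
}

impl Scratch {
    fn with_capacity(bytes: usize) -> Self {
        Scratch { buf: vec![0u8; bytes] }
    }
    fn available(&self) -> usize {
        self.buf.len()
    }
}

// Hypothetical stand-in for e.g.
// GGLWEAutomorphismKey::external_product_scratch_space: the real
// function derives its bound from the layout infos of the output,
// input, and GGSW operands. This bound is illustrative only.
fn external_product_scratch_space(out_size: usize, in_size: usize, ggsw_size: usize) -> usize {
    (out_size + in_size + ggsw_size) * std::mem::size_of::<i64>()
}

fn main() {
    let bytes = external_product_scratch_space(8, 8, 16);
    let scratch = Scratch::with_capacity(bytes);
    // Mirrors the `scratch.available() >= ...` debug assertions in the diff.
    assert!(scratch.available() >= bytes);
    println!("allocated {bytes} scratch bytes");
}
```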
@@ -7,7 +7,7 @@ use poulpy_hal::{
     layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
 };
 
-use crate::layouts::{GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared};
+use crate::layouts::{GGLWEInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared};
 
 impl GGLWESwitchingKey<Vec<u8>> {
     pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
@@ -17,8 +17,8 @@ impl GGLWESwitchingKey<Vec<u8>> {
         ggsw_infos: &GGSW,
     ) -> usize
     where
-        OUT: GGLWELayoutInfos,
-        IN: GGLWELayoutInfos,
+        OUT: GGLWEInfos,
+        IN: GGLWEInfos,
         GGSW: GGSWInfos,
         Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
@@ -36,7 +36,7 @@ impl GGLWESwitchingKey<Vec<u8>> {
         ggsw_infos: &GGSW,
     ) -> usize
     where
-        OUT: GGLWELayoutInfos,
+        OUT: GGLWEInfos,
         GGSW: GGSWInfos,
         Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
@@ -91,13 +91,13 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
         }
 
         (0..self.rank_in().into()).for_each(|col_i| {
-            (0..self.rows().into()).for_each(|row_j| {
+            (0..self.dnum().into()).for_each(|row_j| {
                 self.at_mut(row_j, col_i)
                     .external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
             });
         });
 
-        (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| {
+        (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| {
            (0..self.rank_in().into()).for_each(|col_j| {
                self.at_mut(row_i, col_j).data.zero();
            });
@@ -135,7 +135,7 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
         }
 
         (0..self.rank_in().into()).for_each(|col_i| {
-            (0..self.rows().into()).for_each(|row_j| {
+            (0..self.dnum().into()).for_each(|row_j| {
                 self.at_mut(row_j, col_i)
                     .external_product_inplace(module, rhs, scratch);
             });

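The loops above apply the external product only to the first min(self.dnum(), lhs.dnum()) rows and zero whatever rows the destination has beyond that, so a key with a larger dnum than its source stays well defined. A self-contained sketch of that truncate-and-zero pattern, with plain vectors standing in for the key's GLWE rows:

```rust
// Stand-in illustrating the row handling above: process the rows both
// operands share, then zero the destination's extra rows.
fn apply_rows(dst: &mut [Vec<i64>], src: &[Vec<i64>]) {
    let min_dnum = dst.len().min(src.len());
    for row in 0..min_dnum {
        // Placeholder for `self.at_mut(row_j, col_i).external_product(...)`.
        for (d, s) in dst[row].iter_mut().zip(&src[row]) {
            *d = *s * 2;
        }
    }
    for row in min_dnum..dst.len() {
        // Placeholder for `self.at_mut(row_i, col_j).data.zero()`.
        dst[row].iter_mut().for_each(|x| *x = 0);
    }
}

fn main() {
    let mut dst = vec![vec![1; 4]; 3]; // destination dnum = 3
    let src = vec![vec![5; 4]; 2]; //     source dnum = 2
    apply_rows(&mut dst, &src);
    assert_eq!(dst[2], vec![0; 4]); // extra row zeroed
    println!("{dst:?}");
}
```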
@@ -89,14 +89,14 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
             assert!(scratch.available() >= GGSWCiphertext::external_product_scratch_space(module, self, lhs, rhs))
         }
 
-        let min_rows: usize = self.rows().min(lhs.rows()).into();
+        let min_dnum: usize = self.dnum().min(lhs.dnum()).into();
 
         (0..(self.rank() + 1).into()).for_each(|col_i| {
-            (0..min_rows).for_each(|row_j| {
+            (0..min_dnum).for_each(|row_j| {
                 self.at_mut(row_j, col_i)
                     .external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
             });
-            (min_rows..self.rows().into()).for_each(|row_i| {
+            (min_dnum..self.dnum().into()).for_each(|row_i| {
                 self.at_mut(row_i, col_i).data.zero();
             });
         });
@@ -134,7 +134,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
         }
 
         (0..(self.rank() + 1).into()).for_each(|col_i| {
-            (0..self.rows().into()).for_each(|row_j| {
+            (0..self.dnum().into()).for_each(|row_j| {
                 self.at_mut(row_j, col_i)
                     .external_product_inplace(module, rhs, scratch);
             });

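The column loops run over rank + 1 entries because a GLWE ciphertext of rank k carries k mask polynomials plus one body polynomial, and a GGSW ciphertext stores one such GLWE row per digit and column. A toy illustration of the rank + 1 count (types are stand-ins):

```rust
// A rank-k GLWE ciphertext holds k mask polynomials plus one body,
// hence the `(0..(self.rank() + 1))` column loops above.
struct Glwe {
    rank: usize,
    polys: Vec<Vec<i64>>, // rank + 1 polynomials of n coefficients
}

impl Glwe {
    fn new(rank: usize, n: usize) -> Self {
        Glwe { rank, polys: vec![vec![0; n]; rank + 1] }
    }
    fn cols(&self) -> usize {
        self.rank + 1
    }
}

fn main() {
    let ct = Glwe::new(2, 8);
    assert_eq!(ct.cols(), 3);
    for col in 0..ct.cols() {
        println!("column {} has {} coefficients", col, ct.polys[col].len());
    }
}
```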
@@ -7,7 +7,13 @@ use poulpy_hal::{
     layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
 };
 
-use crate::layouts::{GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGSWCiphertextPrepared};
+use crate::{
+    GLWEExternalProduct, GLWEExternalProductInplace,
+    layouts::{
+        GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, LWEInfos,
+        prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef},
+    },
+};
 
 impl GLWECiphertext<Vec<u8>> {
     #[allow(clippy::too_many_arguments)]
@@ -26,7 +32,7 @@ impl GLWECiphertext<Vec<u8>> {
         let in_size: usize = in_infos
             .k()
             .div_ceil(apply_infos.base2k())
-            .div_ceil(apply_infos.digits().into()) as usize;
+            .div_ceil(apply_infos.dsize().into()) as usize;
         let out_size: usize = out_infos.size();
         let ggsw_size: usize = apply_infos.size();
         let res_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), ggsw_size);

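The `in_size` computation above chains two ceiling divisions: k bits of precision become limbs of base2k bits each, and limbs are then grouped into digits of dsize limbs. A worked numeric check under assumed parameters (k = 54, base2k = 17, dsize = 2):

```rust
// Worked example of the `in_size` arithmetic above, with assumed
// parameters: 54 bits of precision, 17 bits per limb, 2 limbs per digit.
fn main() {
    let k: usize = 54;
    let base2k: usize = 17;
    let dsize: usize = 2;

    let limbs = k.div_ceil(base2k); //    ceil(54 / 17) = 4 limbs
    let in_size = limbs.div_ceil(dsize); // ceil(4 / 2) = 2 digit groups
    assert_eq!((limbs, in_size), (4, 2));
    println!("limbs = {limbs}, in_size = {in_size}");
}
```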
@@ -71,70 +77,221 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
         rhs: &GGSWCiphertextPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
-            + VmpApplyDftToDftTmpBytes
-            + VecZnxNormalizeTmpBytes
-            + VecZnxDftApply<B>
-            + VmpApplyDftToDft<B>
-            + VmpApplyDftToDftAdd<B>
-            + VecZnxIdftApplyConsume<B>
-            + VecZnxBigNormalize<B>
-            + VecZnxNormalize<B>,
-        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
+        Module<B>: GLWEExternalProduct<B>,
     {
+        module.external_product(self, lhs, rhs, scratch);
+    }
+
+    pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
+        &mut self,
+        module: &Module<B>,
+        rhs: &GGSWCiphertextPrepared<DataRhs, B>,
+        scratch: &mut Scratch<B>,
+    ) where
+        Module<B>: GLWEExternalProductInplace<B>,
+    {
+        module.external_product_inplace(self, rhs, scratch);
+    }
 }
+
+impl<BE: Backend> GLWEExternalProductInplace<BE> for Module<BE>
+where
+    Module<BE>: VecZnxDftAllocBytes
+        + VmpApplyDftToDftTmpBytes
+        + VecZnxNormalizeTmpBytes
+        + VecZnxDftApply<BE>
+        + VmpApplyDftToDft<BE>
+        + VmpApplyDftToDftAdd<BE>
+        + VecZnxIdftApplyConsume<BE>
+        + VecZnxBigNormalize<BE>
+        + VecZnxNormalize<BE>,
+    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+{
+    fn external_product_inplace<R, D>(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWECiphertextToMut,
+        D: GGSWCiphertextPreparedToRef<BE>,
+    {
+        let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut();
+        let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &ggsw.to_ref();
+
+        let basek_in: usize = res.base2k().into();
+        let basek_ggsw: usize = rhs.base2k().into();
+
+        #[cfg(debug_assertions)]
+        {
+            use poulpy_hal::api::ScratchAvailable;
+
+            assert_eq!(rhs.rank(), res.rank());
+            assert_eq!(rhs.n(), res.n());
+            assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(self, res, rhs));
+        }
+
+        let cols: usize = (rhs.rank() + 1).into();
+        let dsize: usize = rhs.dsize().into();
+        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
+
+        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(res.n().into(), cols, rhs.size()); // Todo optimise
+        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), cols, a_size.div_ceil(dsize));
+        a_dft.data_mut().fill(0);
+
+        if basek_in == basek_ggsw {
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((res.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
+                }
+            }
+        } else {
+            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
+
+            for j in 0..cols {
+                self.vec_znx_normalize(
+                    basek_ggsw,
+                    &mut a_conv,
+                    j,
+                    basek_in,
+                    &res.data,
+                    j,
+                    scratch_3,
+                );
+            }
+
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((res.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
+                }
+            }
+        }
+
+        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
+
+        for j in 0..cols {
+            self.vec_znx_big_normalize(
+                basek_in,
+                &mut res.data,
+                j,
+                basek_ggsw,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
+    }
+}
+
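The limb-truncation rule in the comments above can be checked numerically. With assumed values rhs.size() = 8 and dsize = 3, only the first iteration drops a trailing limb, since the error contributed at digit di is scaled by 2^{di * B}:

```rust
// Numeric check of the truncation rule used above:
// res_dft.set_size(rhs.size() - max((dsize - di) - 2, 0)).
fn truncated_size(rhs_size: usize, dsize: usize, di: usize) -> usize {
    rhs_size - ((dsize - di) as isize - 2).max(0) as usize
}

fn main() {
    let (rhs_size, dsize) = (8usize, 3usize);
    let sizes: Vec<usize> = (0..dsize).map(|di| truncated_size(rhs_size, dsize, di)).collect();
    // di = 0 drops one limb; later digits keep the full size.
    assert_eq!(sizes, vec![7, 8, 8]);
    println!("res_dft sizes per digit: {sizes:?}");
}
```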
+impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
+where
+    Module<BE>: VecZnxDftAllocBytes
+        + VmpApplyDftToDftTmpBytes
+        + VecZnxNormalizeTmpBytes
+        + VecZnxDftApply<BE>
+        + VmpApplyDftToDft<BE>
+        + VmpApplyDftToDftAdd<BE>
+        + VecZnxIdftApplyConsume<BE>
+        + VecZnxBigNormalize<BE>
+        + VecZnxNormalize<BE>,
+    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+{
+    fn external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWECiphertextToMut,
+        A: GLWECiphertextToRef,
+        D: GGSWCiphertextPreparedToRef<BE>,
+    {
+        let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut();
+        let lhs: &GLWECiphertext<&[u8]> = &lhs.to_ref();
+
+        let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &rhs.to_ref();
 
         let basek_in: usize = lhs.base2k().into();
         let basek_ggsw: usize = rhs.base2k().into();
-        let basek_out: usize = self.base2k().into();
+        let basek_out: usize = res.base2k().into();
 
         #[cfg(debug_assertions)]
         {
             use poulpy_hal::api::ScratchAvailable;
 
             assert_eq!(rhs.rank(), lhs.rank());
-            assert_eq!(rhs.rank(), self.rank());
-            assert_eq!(rhs.n(), self.n());
-            assert_eq!(lhs.n(), self.n());
-            assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(module, self, lhs, rhs));
+            assert_eq!(rhs.rank(), res.rank());
+            assert_eq!(rhs.n(), res.n());
+            assert_eq!(lhs.n(), res.n());
+            assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(self, res, lhs, rhs));
         }
 
         let cols: usize = (rhs.rank() + 1).into();
-        let digits: usize = rhs.digits().into();
+        let dsize: usize = rhs.dsize().into();
 
         let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
 
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits));
+        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
+        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, a_size.div_ceil(dsize));
         a_dft.data_mut().fill(0);
 
         if basek_in == basek_ggsw {
-            for di in 0..digits {
-                // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
-                a_dft.set_size((lhs.size() + di) / digits);
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((lhs.size() + di) / dsize);
 
-                // Small optimization for digits > 2
+                // Small optimization for dsize > 2
                 // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
-                // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
-                // It is possible to further ignore the last digits-1 limbs, but this introduces
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
                 // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
                 // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
+                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
 
                 for j in 0..cols {
-                    module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &lhs.data, j);
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
                 }
 
                 if di == 0 {
-                    module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
                 } else {
-                    module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
                 }
             }
         } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size);
+            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
 
             for j in 0..cols {
-                module.vec_znx_normalize(
+                self.vec_znx_normalize(
                     basek_ggsw,
                     &mut a_conv,
                     j,
@@ -145,37 +302,37 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
                 );
             }
 
-            for di in 0..digits {
-                // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
-                a_dft.set_size((a_size + di) / digits);
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a_size + di) / dsize);
 
-                // Small optimization for digits > 2
+                // Small optimization for dsize > 2
                 // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
-                // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
-                // It is possible to further ignore the last digits-1 limbs, but this introduces
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
                 // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
                 // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
+                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
 
                 for j in 0..cols {
-                    module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &a_conv, j);
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
                 }
 
                 if di == 0 {
-                    module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
                 } else {
-                    module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
                 }
             }
         }
 
-        let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
+        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
 
         (0..cols).for_each(|i| {
-            module.vec_znx_big_normalize(
+            self.vec_znx_big_normalize(
                 basek_out,
-                &mut self.data,
+                res.data_mut(),
                 i,
                 basek_ggsw,
                 &res_big,
@@ -184,120 +341,4 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
             );
         });
     }
-
-    pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
-        &mut self,
-        module: &Module<B>,
-        rhs: &GGSWCiphertextPrepared<DataRhs, B>,
-        scratch: &mut Scratch<B>,
-    ) where
-        Module<B>: VecZnxDftAllocBytes
-            + VmpApplyDftToDftTmpBytes
-            + VecZnxNormalizeTmpBytes
-            + VecZnxDftApply<B>
-            + VmpApplyDftToDft<B>
-            + VmpApplyDftToDftAdd<B>
-            + VecZnxIdftApplyConsume<B>
-            + VecZnxBigNormalize<B>
-            + VecZnxNormalize<B>,
-        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
-    {
-        let basek_in: usize = self.base2k().into();
-        let basek_ggsw: usize = rhs.base2k().into();
-
-        #[cfg(debug_assertions)]
-        {
-            use poulpy_hal::api::ScratchAvailable;
-
-            assert_eq!(rhs.rank(), self.rank());
-            assert_eq!(rhs.n(), self.n());
-            assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(module, self, rhs,));
-        }
-
-        let cols: usize = (rhs.rank() + 1).into();
-        let digits: usize = rhs.digits().into();
-        let a_size: usize = (self.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..digits {
-                // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
-                a_dft.set_size((self.size() + di) / digits);
-
-                // Small optimization for digits > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
-                // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
-                // It is possible to further ignore the last digits-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &self.data, j);
-                }
-
-                if di == 0 {
-                    module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size);
-
-            for j in 0..cols {
-                module.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &self.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..digits {
-                // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
-                a_dft.set_size((self.size() + di) / digits);
-
-                // Small optimization for digits > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
-                // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
-                // It is possible to further ignore the last digits-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &self.data, j);
-                }
-
-                if di == 0 {
-                    module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
-
-        for j in 0..cols {
-            module.vec_znx_big_normalize(
-                basek_in,
-                &mut self.data,
-                j,
-                basek_ggsw,
-                &res_big,
-                j,
-                scratch_1,
-            );
-        }
-    }
 }

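When basek_in != basek_ggsw, the code above first renormalizes the input into the GGSW base (vec_znx_normalize) and converts the result back at the end (vec_znx_big_normalize). The toy below shows only the bit-regrouping idea on a single unsigned coefficient; the real routines operate on polynomial vectors and use a balanced signed limb representation:

```rust
// Toy base-2^k limb conversion: the same integer carried as limbs in
// two different bases; converting just regroups the bits.
fn to_limbs(mut v: u64, base2k: u32, size: usize) -> Vec<u64> {
    // Most significant limb first.
    let mask = (1u64 << base2k) - 1;
    let mut limbs = vec![0u64; size];
    for limb in limbs.iter_mut().rev() {
        *limb = v & mask;
        v >>= base2k;
    }
    limbs
}

fn from_limbs(limbs: &[u64], base2k: u32) -> u64 {
    limbs.iter().fold(0, |acc, &l| (acc << base2k) | l)
}

fn main() {
    let x: u64 = 0xDEAD_BEEF;
    let in_limbs = to_limbs(x, 16, 4); //  basek_in  = 16
    let regrouped = to_limbs(x, 8, 8); //  basek_ggsw = 8
    assert_eq!(from_limbs(&in_limbs, 16), from_limbs(&regrouped, 8));
    println!("{in_limbs:x?} == {regrouped:x?}");
}
```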
@@ -1,4 +1,23 @@
 use poulpy_hal::layouts::{Backend, Scratch};
 
+use crate::layouts::{GLWECiphertextToMut, GLWECiphertextToRef, prepared::GGSWCiphertextPreparedToRef};
+
 mod gglwe_atk;
 mod gglwe_ksk;
 mod ggsw_ct;
 mod glwe_ct;
+
+pub trait GLWEExternalProduct<BE: Backend> {
+    fn external_product<R, A, D>(&self, res: &mut R, a: &A, ggsw: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWECiphertextToMut,
+        A: GLWECiphertextToRef,
+        D: GGSWCiphertextPreparedToRef<BE>;
+}
+
+pub trait GLWEExternalProductInplace<BE: Backend> {
+    fn external_product_inplace<R, D>(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWECiphertextToMut,
+        D: GGSWCiphertextPreparedToRef<BE>;
+}
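With these two traits in place, the algorithm lives on the module and the ciphertext methods merely delegate, as the wrappers earlier in the diff show. A minimal runnable stand-in of that call shape (all types are simplified placeholders, not poulpy's):

```rust
// Stand-ins mirroring the trait shape introduced above: the module
// (backend) owns the algorithm; ciphertext types only delegate.
struct Module;
struct Glwe(Vec<i64>);
struct GgswPrepared(i64);
struct Scratch;

trait GLWEExternalProduct {
    fn external_product(&self, res: &mut Glwe, a: &Glwe, ggsw: &GgswPrepared, scratch: &mut Scratch);
}

impl GLWEExternalProduct for Module {
    fn external_product(&self, res: &mut Glwe, a: &Glwe, ggsw: &GgswPrepared, _scratch: &mut Scratch) {
        // Placeholder arithmetic standing in for the DFT/VMP pipeline.
        res.0 = a.0.iter().map(|c| c * ggsw.0).collect();
    }
}

impl Glwe {
    // Mirrors GLWECiphertext::external_product delegating to the module.
    fn external_product(&mut self, module: &Module, a: &Glwe, ggsw: &GgswPrepared, scratch: &mut Scratch) {
        module.external_product(self, a, ggsw, scratch);
    }
}

fn main() {
    let module = Module;
    let a = Glwe(vec![1, 2, 3]);
    let ggsw = GgswPrepared(5);
    let mut res = Glwe(vec![0; 3]);
    res.external_product(&module, &a, &ggsw, &mut Scratch);
    assert_eq!(res.0, vec![5, 10, 15]);
}
```

Moving the bodies behind the traits keeps the long backend trait bounds on the Module impls instead of repeating them on every ciphertext method.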