Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 13:16:44 +01:00)

Commit: wip
@@ -1,6 +1,6 @@
 use poulpy_hal::{
     api::{
-        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
+        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
         VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
         VmpApplyDftToDftTmpBytes,
     },
@@ -20,7 +20,7 @@ impl AutomorphismKey<Vec<u8>> {
         OUT: GGLWEInfos,
         IN: GGLWEInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos)
     }
@@ -33,7 +33,7 @@ impl AutomorphismKey<Vec<u8>> {
     where
         OUT: GGLWEInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos)
     }
@@ -47,7 +47,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>
@@ -67,7 +67,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>

@@ -1,6 +1,6 @@
 use poulpy_hal::{
     api::{
-        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
+        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
         VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
         VmpApplyDftToDftTmpBytes,
     },
@@ -20,7 +20,7 @@ impl GLWESwitchingKey<Vec<u8>> {
         OUT: GGLWEInfos,
         IN: GGLWEInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWE::external_product_scratch_space(
             module,
@@ -38,7 +38,7 @@ impl GLWESwitchingKey<Vec<u8>> {
     where
         OUT: GGLWEInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
     }
@@ -52,7 +52,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>
@@ -110,7 +110,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>

@@ -1,6 +1,6 @@
 use poulpy_hal::{
     api::{
-        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
+        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
         VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
         VmpApplyDftToDftTmpBytes,
     },
@@ -21,7 +21,7 @@ impl GGSW<Vec<u8>> {
         OUT: GGSWInfos,
         IN: GGSWInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWE::external_product_scratch_space(
             module,
@@ -39,7 +39,7 @@ impl GGSW<Vec<u8>> {
     where
         OUT: GGSWInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
     {
         GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos)
     }
@@ -53,7 +53,7 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>
@@ -108,7 +108,7 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
         rhs: &GGSWPrepared<DataRhs, B>,
         scratch: &mut Scratch<B>,
     ) where
-        Module<B>: VecZnxDftAllocBytes
+        Module<B>: VecZnxDftBytesOf
             + VmpApplyDftToDftTmpBytes
             + VecZnxNormalizeTmpBytes
             + VecZnxDftApply<B>

@@ -1,21 +1,22 @@
 use poulpy_hal::{
     api::{
-        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
-        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
-        VmpApplyDftToDftTmpBytes,
+        ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
+        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
     },
     layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
 };
 
-use crate::layouts::{
-    GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
-    prepared::{GGSWCiphertextPreparedToRef, GGSWPrepared},
+use crate::{
+    ScratchTakeCore,
+    layouts::{
+        GGSWInfos, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetDegree, LWEInfos,
+        prepared::{GGSWCiphertextPreparedToRef, GGSWPrepared},
+    },
 };
 
-impl GLWE<Vec<u8>> {
-    #[allow(clippy::too_many_arguments)]
-    pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
-        module: &Module<B>,
+impl<DataSelf: DataMut> GLWE<DataSelf> {
+    pub fn external_product_scratch_space<OUT, IN, GGSW, B: Backend>(
+        module: Module<B>,
         out_infos: &OUT,
         in_infos: &IN,
         apply_infos: &GGSW,
@@ -24,76 +25,35 @@ impl GLWE<Vec<u8>> {
         OUT: GLWEInfos,
         IN: GLWEInfos,
         GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
+        Module<B>: GLWEExternalProduct<B>,
     {
-        let in_size: usize = in_infos
-            .k()
-            .div_ceil(apply_infos.base2k())
-            .div_ceil(apply_infos.dsize().into()) as usize;
-        let out_size: usize = out_infos.size();
-        let ggsw_size: usize = apply_infos.size();
-        let res_dft: usize = module.vec_znx_dft_bytes_of((apply_infos.rank() + 1).into(), ggsw_size);
-        let a_dft: usize = module.vec_znx_dft_bytes_of((apply_infos.rank() + 1).into(), in_size);
-        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
-            out_size,
-            in_size,
-            in_size, // rows
-            (apply_infos.rank() + 1).into(), // cols in
-            (apply_infos.rank() + 1).into(), // cols out
-            ggsw_size,
-        );
-        let normalize_big: usize = module.vec_znx_normalize_tmp_bytes();
-
-        if in_infos.base2k() == apply_infos.base2k() {
-            res_dft + a_dft + (vmp | normalize_big)
-        } else {
-            let normalize_conv: usize = VecZnx::bytes_of(module.n(), (apply_infos.rank() + 1).into(), in_size);
-            res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
-        }
+        module.glwe_external_product_scratch_space(out_infos, in_infos, apply_infos)
     }
 
-    pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
-        module: &Module<B>,
-        out_infos: &OUT,
-        apply_infos: &GGSW,
-    ) -> usize
+    pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>)
     where
-        OUT: GLWEInfos,
-        GGSW: GGSWInfos,
-        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
-    {
-        Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos)
-    }
-}
-
-impl<DataSelf: DataMut> GLWE<DataSelf> {
-    pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
-        &mut self,
-        module: &Module<B>,
-        lhs: &GLWE<DataLhs>,
-        rhs: &GGSWPrepared<DataRhs, B>,
-        scratch: &mut Scratch<B>,
-    ) where
+        L: GLWEToRef,
+        R: GGSWToRef,
+        Module<B>: GLWEExternalProduct<B>,
+        Scratch<B>: ScratchTakeCore<B>,
     {
-        module.external_product(self, lhs, rhs, scratch);
+        module.glwe_external_product(self, lhs, rhs, scratch);
     }
 
-    pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
-        &mut self,
-        module: &Module<B>,
-        rhs: &GGSWPrepared<DataRhs, B>,
-        scratch: &mut Scratch<B>,
-    ) where
+    pub fn external_product_inplace<R, B: Backend>(&mut self, module: &Module<B>, rhs: &R, scratch: &mut Scratch<B>)
+    where
+        R: GGSWToRef,
+        Module<B>: GLWEExternalProduct<B>,
+        Scratch<B>: ScratchTakeCore<B>,
     {
-        module.external_product_inplace(self, rhs, scratch);
+        module.glwe_external_product_inplace(self, rhs, scratch);
     }
 }
 
 pub trait GLWEExternalProduct<BE: Backend>
 where
-    Self: VecZnxDftAllocBytes
+    Self: GetDegree
+        + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
         + VecZnxNormalizeTmpBytes
         + VecZnxDftApply<BE>
@@ -101,13 +61,49 @@ where
         + VmpApplyDftToDftAdd<BE>
         + VecZnxIdftApplyConsume<BE>
         + VecZnxBigNormalize<BE>
-        + VecZnxNormalize<BE>,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+        + VecZnxNormalize<BE>
+        + VecZnxDftBytesOf
+        + VmpApplyDftToDftTmpBytes
+        + VecZnxNormalizeTmpBytes,
 {
+    #[allow(clippy::too_many_arguments)]
+    fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, apply_infos: &GGSW) -> usize
+    where
+        OUT: GLWEInfos,
+        IN: GLWEInfos,
+        GGSW: GGSWInfos,
+    {
+        let in_size: usize = in_infos
+            .k()
+            .div_ceil(apply_infos.base2k())
+            .div_ceil(apply_infos.dsize().into()) as usize;
+        let out_size: usize = out_infos.size();
+        let ggsw_size: usize = apply_infos.size();
+        let res_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), ggsw_size);
+        let a_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), in_size);
+        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
+            out_size,
+            in_size,
+            in_size, // rows
+            (apply_infos.rank() + 1).into(), // cols in
+            (apply_infos.rank() + 1).into(), // cols out
+            ggsw_size,
+        );
+        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
+
+        if in_infos.base2k() == apply_infos.base2k() {
+            res_dft + a_dft + (vmp | normalize_big)
+        } else {
+            let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (apply_infos.rank() + 1).into(), in_size);
+            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
+        }
+    }
+
     fn glwe_external_product_inplace<R, D>(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         D: GGSWCiphertextPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
         let rhs: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
@@ -121,7 +117,7 @@ where
 
             assert_eq!(rhs.rank(), res.rank());
             assert_eq!(rhs.n(), res.n());
-            assert!(scratch.available() >= GLWE::external_product_inplace_scratch_space(self, res, rhs));
+            assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, res, rhs));
         }
 
         let cols: usize = (rhs.rank() + 1).into();
@@ -157,7 +153,7 @@
                 }
             }
         } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
+            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
 
             for j in 0..cols {
                 self.vec_znx_normalize(
@@ -216,6 +212,7 @@
         R: GLWEToMut,
         A: GLWEToRef,
         D: GGSWCiphertextPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
         let lhs: &GLWE<&[u8]> = &lhs.to_ref();
@@ -234,7 +231,7 @@
             assert_eq!(rhs.rank(), res.rank());
             assert_eq!(rhs.n(), res.n());
             assert_eq!(lhs.n(), res.n());
-            assert!(scratch.available() >= GLWE::external_product_scratch_space(self, res, lhs, rhs));
+            assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, lhs, rhs));
         }
 
         let cols: usize = (rhs.rank() + 1).into();
@@ -242,8 +239,8 @@
 
         let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
 
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, a_size.div_ceil(dsize));
+        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
+        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(dsize));
         a_dft.data_mut().fill(0);
 
         if basek_in == basek_ggsw {
@@ -271,7 +268,7 @@
                 }
             }
         } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
+            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
 
             for j in 0..cols {
                 self.vec_znx_normalize(
@@ -326,9 +323,9 @@
     }
 }
 
-impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
-where
-    Self: VecZnxDftAllocBytes
+impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
+    Self: GetDegree
+        + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
         + VecZnxNormalizeTmpBytes
         + VecZnxDftApply<BE>
@@ -336,7 +333,9 @@
         + VmpApplyDftToDftAdd<BE>
         + VecZnxIdftApplyConsume<BE>
        + VecZnxBigNormalize<BE>
-        + VecZnxNormalize<BE>,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+        + VecZnxNormalize<BE>
+        + VecZnxDftBytesOf
+        + VmpApplyDftToDftTmpBytes
+        + VecZnxNormalizeTmpBytes
 {
 }

@@ -1,8 +1,6 @@
-use poulpy_hal::layouts::{Backend, Scratch};
-
-use crate::layouts::{GLWEToMut, GLWEToRef, prepared::GGSWCiphertextPreparedToRef};
-
 mod gglwe_atk;
 mod gglwe_ksk;
 mod ggsw_ct;
 mod glwe_ct;
+
+pub use glwe_ct::*;
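For orientation, a minimal sketch of how a caller would drive the API after this refactor. It is not part of the commit: the helper name apply_external_product is hypothetical, and the import paths simply mirror the use statements visible in the diff above, so they may not match the crate's actual re-export layout.

use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch};

use crate::{
    GLWEExternalProduct, ScratchTakeCore,
    layouts::{GGSWToRef, GLWE, GLWEToRef},
};

/// Hypothetical helper: applies the external product of `lhs` with the GGSW operand `rhs`
/// into `out`. The bounds are the ones the refactored `GLWE::external_product` asks for:
/// the inputs only need `GLWEToRef` / `GGSWToRef`, and the backend module must implement
/// the `GLWEExternalProduct` trait, which now also owns the scratch-space sizing.
fn apply_external_product<D, L, R, B>(
    module: &Module<B>,
    out: &mut GLWE<D>,
    lhs: &L,
    rhs: &R,
    scratch: &mut Scratch<B>,
) where
    D: DataMut,
    L: GLWEToRef,
    R: GGSWToRef,
    B: Backend,
    Module<B>: GLWEExternalProduct<B>,
    Scratch<B>: ScratchTakeCore<B>,
{
    // Delegates to Module::glwe_external_product internally.
    out.external_product(module, lhs, rhs, scratch);
}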