This commit is contained in:
Pro7ech
2025-10-14 23:39:16 +02:00
parent 72dca47cbe
commit 779e02acc4
94 changed files with 784 additions and 1688 deletions

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
@@ -20,7 +20,7 @@ impl AutomorphismKey<Vec<u8>> {
OUT: GGLWEInfos,
IN: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos)
}
@@ -33,7 +33,7 @@ impl AutomorphismKey<Vec<u8>> {
where
OUT: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos)
}
@@ -47,7 +47,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
@@ -67,7 +67,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
@@ -20,7 +20,7 @@ impl GLWESwitchingKey<Vec<u8>> {
OUT: GGLWEInfos,
IN: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWE::external_product_scratch_space(
module,
@@ -38,7 +38,7 @@ impl GLWESwitchingKey<Vec<u8>> {
where
OUT: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
}
@@ -52,7 +52,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
@@ -110,7 +110,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
@@ -21,7 +21,7 @@ impl GGSW<Vec<u8>> {
OUT: GGSWInfos,
IN: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWE::external_product_scratch_space(
module,
@@ -39,7 +39,7 @@ impl GGSW<Vec<u8>> {
where
OUT: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos)
}
@@ -53,7 +53,7 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
@@ -108,7 +108,7 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>

View File

@@ -1,21 +1,22 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
};
use crate::layouts::{
GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
prepared::{GGSWCiphertextPreparedToRef, GGSWPrepared},
use crate::{
ScratchTakeCore,
layouts::{
GGSWInfos, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetDegree, LWEInfos,
prepared::{GGSWCiphertextPreparedToRef, GGSWPrepared},
},
};
impl GLWE<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product_scratch_space<OUT, IN, GGSW, B: Backend>(
module: Module<B>,
out_infos: &OUT,
in_infos: &IN,
apply_infos: &GGSW,
@@ -24,76 +25,35 @@ impl GLWE<Vec<u8>> {
OUT: GLWEInfos,
IN: GLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
Module<B>: GLWEExternalProduct<B>,
{
let in_size: usize = in_infos
.k()
.div_ceil(apply_infos.base2k())
.div_ceil(apply_infos.dsize().into()) as usize;
let out_size: usize = out_infos.size();
let ggsw_size: usize = apply_infos.size();
let res_dft: usize = module.vec_znx_dft_bytes_of((apply_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = module.vec_znx_dft_bytes_of((apply_infos.rank() + 1).into(), in_size);
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(apply_infos.rank() + 1).into(), // cols in
(apply_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = module.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == apply_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(module.n(), (apply_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
module.glwe_external_product_scratch_space(out_infos, in_infos, apply_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
apply_infos: &GGSW,
) -> usize
pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>)
where
OUT: GLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos)
}
}
impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GLWE<DataLhs>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
L: GLWEToRef,
R: GGSWToRef,
Module<B>: GLWEExternalProduct<B>,
Scratch<B>: ScratchTakeCore<B>,
{
module.external_product(self, lhs, rhs, scratch);
module.glwe_external_product(self, lhs, rhs, scratch);
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
pub fn external_product_inplace<R, B: Backend>(&mut self, module: &Module<B>, rhs: &R, scratch: &mut Scratch<B>)
where
R: GGSWToRef,
Module<B>: GLWEExternalProduct<B>,
Scratch<B>: ScratchTakeCore<B>,
{
module.external_product_inplace(self, rhs, scratch);
module.glwe_external_product_inplace(self, rhs, scratch);
}
}
pub trait GLWEExternalProduct<BE: Backend>
where
Self: VecZnxDftAllocBytes
Self: GetDegree
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<BE>
@@ -101,13 +61,49 @@ where
+ VmpApplyDftToDftAdd<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+ VecZnxNormalize<BE>
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes,
{
#[allow(clippy::too_many_arguments)]
fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, apply_infos: &GGSW) -> usize
where
OUT: GLWEInfos,
IN: GLWEInfos,
GGSW: GGSWInfos,
{
let in_size: usize = in_infos
.k()
.div_ceil(apply_infos.base2k())
.div_ceil(apply_infos.dsize().into()) as usize;
let out_size: usize = out_infos.size();
let ggsw_size: usize = apply_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(apply_infos.rank() + 1).into(), // cols in
(apply_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == apply_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (apply_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}
fn glwe_external_product_inplace<R, D>(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
D: GGSWCiphertextPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let rhs: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
@@ -121,7 +117,7 @@ where
assert_eq!(rhs.rank(), res.rank());
assert_eq!(rhs.n(), res.n());
assert!(scratch.available() >= GLWE::external_product_inplace_scratch_space(self, res, rhs));
assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, res, rhs));
}
let cols: usize = (rhs.rank() + 1).into();
@@ -157,7 +153,7 @@ where
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
@@ -216,6 +212,7 @@ where
R: GLWEToMut,
A: GLWEToRef,
D: GGSWCiphertextPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let lhs: &GLWE<&[u8]> = &lhs.to_ref();
@@ -234,7 +231,7 @@ where
assert_eq!(rhs.rank(), res.rank());
assert_eq!(rhs.n(), res.n());
assert_eq!(lhs.n(), res.n());
assert!(scratch.available() >= GLWE::external_product_scratch_space(self, res, lhs, rhs));
assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, lhs, rhs));
}
let cols: usize = (rhs.rank() + 1).into();
@@ -242,8 +239,8 @@ where
let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, a_size.div_ceil(dsize));
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
@@ -271,7 +268,7 @@ where
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
@@ -326,9 +323,9 @@ where
}
}
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
where
Self: VecZnxDftAllocBytes
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
Self: GetDegree
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<BE>
@@ -336,7 +333,9 @@ where
+ VmpApplyDftToDftAdd<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx,
+ VecZnxNormalize<BE>
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
{
}

View File

@@ -1,8 +1,6 @@
use poulpy_hal::layouts::{Backend, Scratch};
use crate::layouts::{GLWEToMut, GLWEToRef, prepared::GGSWCiphertextPreparedToRef};
mod gglwe_atk;
mod gglwe_ksk;
mod ggsw_ct;
mod glwe_ct;
pub use glwe_ct::*;