This commit is contained in:
Pro7ech
2025-10-15 10:48:14 +02:00
parent a5df85170d
commit 008b800c01
74 changed files with 890 additions and 871 deletions

View File

@@ -1,8 +1,7 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
};
@@ -56,7 +55,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
Scratch<B>: ScratchAvailable,
{
self.key.external_product(module, &lhs.key, rhs, scratch);
}
@@ -76,7 +75,7 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
Scratch<B>: ScratchAvailable,
{
self.key.external_product_inplace(module, rhs, scratch);
}

View File

@@ -1,8 +1,7 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
@@ -61,7 +60,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
Scratch<B>: ScratchAvailable,
{
#[cfg(debug_assertions)]
{
@@ -119,7 +118,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
Scratch<B>: ScratchAvailable,
{
#[cfg(debug_assertions)]
{

View File

@@ -1,47 +1,116 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
api::ScratchAvailable,
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
use crate::layouts::{GGSW, GGSWInfos, GLWE, GLWEInfos, prepared::GGSWPrepared};
use crate::{
GLWEExternalProduct, ScratchTakeCore,
layouts::{
GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos,
prepared::{GGSWPrepared, GGSWPreparedToRef},
},
};
impl GGSW<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
) -> usize
pub trait GGSWExternalProduct<BE: Backend>
where
Self: GLWEExternalProduct<BE>,
{
fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
OUT: GGSWInfos,
IN: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
R: GGSWInfos,
A: GGSWInfos,
B: GGSWInfos,
{
GLWE::external_product_scratch_space(
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
ggsw_infos,
)
self.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
ggsw_infos: &GGSW,
fn ggsw_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
A: GGSWToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref();
let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
assert_eq!(
res.rank(),
a.rank(),
"res rank: {} != a rank: {}",
res.rank(),
a.rank()
);
assert_eq!(
res.rank(),
b.rank(),
"res rank: {} != b rank: {}",
res.rank(),
b.rank()
);
assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b));
let min_dnum: usize = res.dnum().min(a.dnum()).into();
for row in 0..min_dnum {
for col in 0..(res.rank() + 1).into() {
self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch);
}
}
for row in min_dnum..res.dnum().into() {
for col in 0..(res.rank() + 1).into() {
res.at_mut(row, col).data.zero();
}
}
}
fn ggsw_external_product_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSWPrepared<&[u8], BE> = &a.to_ref();
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32);
assert_eq!(
res.rank(),
a.rank(),
"res rank: {} != a rank: {}",
res.rank(),
a.rank()
);
for row in 0..res.dnum().into() {
for col in 0..(res.rank() + 1).into() {
self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch);
}
}
}
}
impl<BE: Backend> GGSWExternalProduct<BE> for Module<BE> where Self: GLWEExternalProduct<BE> {}
impl GGSW<Vec<u8>> {
pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(
&self,
module: &M,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
where
OUT: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
R: GGSWInfos,
A: GGSWInfos,
B: GGSWInfos,
M: GGSWExternalProduct<BE>,
{
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
module.ggsw_external_product_tmp_bytes(res_infos, a_infos, b_infos)
}
}
@@ -52,54 +121,7 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
lhs: &GGSW<DataLhs>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::LWEInfos;
assert_eq!(lhs.n(), self.n());
assert_eq!(rhs.n(), self.n());
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_in rank: {} != ggsw_apply rank: {}",
self.rank(),
rhs.rank()
);
assert!(scratch.available() >= GGSW::external_product_scratch_space(module, self, lhs, rhs))
}
let min_dnum: usize = self.dnum().min(lhs.dnum()).into();
(0..(self.rank() + 1).into()).for_each(|col_i| {
(0..min_dnum).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
});
(min_dnum..self.dnum().into()).for_each(|row_i| {
self.at_mut(row_i, col_i).data.zero();
});
});
) {
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
@@ -107,37 +129,6 @@ impl<DataSelf: DataMut> GGSW<DataSelf> {
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::LWEInfos;
assert_eq!(rhs.n(), self.n());
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_out rank: {} != ggsw_apply: {}",
self.rank(),
rhs.rank()
);
}
(0..(self.rank() + 1).into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product_inplace(module, rhs, scratch);
});
});
) {
}
}

View File

@@ -1,58 +1,59 @@
use poulpy_hal::{
api::{
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
};
use crate::{
ScratchTakeCore,
layouts::{
GGSWInfos, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetDegree, LWEInfos,
prepared::{GGSWCiphertextPreparedToRef, GGSWPrepared},
GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
prepared::{GGSWPrepared, GGSWPreparedToRef},
},
};
impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product_scratch_space<OUT, IN, GGSW, B: Backend>(
module: Module<B>,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
pub fn external_product_scratch_space<R, A, B, BE: Backend>(
module: Module<BE>,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
where
OUT: GLWEInfos,
IN: GLWEInfos,
GGSW: GGSWInfos,
Module<B>: GLWEExternalProduct<B>,
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos,
Module<BE>: GLWEExternalProduct<BE>,
{
module.glwe_external_product_scratch_space(out_infos, in_infos, ggsw_infos)
module.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
}
pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>)
pub fn external_product<A, B, BE: Backend>(&mut self, module: &Module<BE>, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
L: GLWEToRef,
R: GGSWToRef,
Module<B>: GLWEExternalProduct<B>,
Scratch<B>: ScratchTakeCore<B>,
A: GLWEToRef,
B: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.glwe_external_product(self, lhs, rhs, scratch);
module.glwe_external_product(self, a, b, scratch);
}
pub fn external_product_inplace<R, B: Backend>(&mut self, module: &Module<B>, rhs: &R, scratch: &mut Scratch<B>)
pub fn external_product_inplace<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A, scratch: &mut Scratch<BE>)
where
R: GGSWToRef,
Module<B>: GLWEExternalProduct<B>,
Scratch<B>: ScratchTakeCore<B>,
A: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.glwe_external_product_inplace(self, rhs, scratch);
module.glwe_external_product_inplace(self, a, scratch);
}
}
pub trait GLWEExternalProduct<BE: Backend>
where
Self: GetDegree
Self: Sized
+ ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
@@ -61,52 +62,48 @@ where
+ VmpApplyDftToDftAdd<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes,
+ VecZnxNormalize<BE>,
{
#[allow(clippy::too_many_arguments)]
fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, ggsw_infos: &GGSW) -> usize
fn glwe_external_product_scratch_space<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
OUT: GLWEInfos,
IN: GLWEInfos,
GGSW: GGSWInfos,
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos,
{
let in_size: usize = in_infos
let in_size: usize = a_infos
.k()
.div_ceil(ggsw_infos.base2k())
.div_ceil(ggsw_infos.dsize().into()) as usize;
let out_size: usize = out_infos.size();
let ggsw_size: usize = ggsw_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), in_size);
.div_ceil(b_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size();
let ggsw_size: usize = b_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(ggsw_infos.rank() + 1).into(), // cols in
(ggsw_infos.rank() + 1).into(), // cols out
in_size, // rows
(b_infos.rank() + 1).into(), // cols in
(b_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == ggsw_infos.base2k() {
if a_infos.base2k() == b_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (ggsw_infos.rank() + 1).into(), in_size);
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}
fn glwe_external_product_inplace<R, D>(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch<BE>)
fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
D: GGSWCiphertextPreparedToRef<BE>,
D: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let rhs: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref();
let basek_in: usize = res.base2k().into();
let basek_ggsw: usize = rhs.base2k().into();
@@ -124,8 +121,8 @@ where
let dsize: usize = rhs.dsize().into();
let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(res.n().into(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), cols, a_size.div_ceil(dsize));
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
@@ -153,7 +150,7 @@ where
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
@@ -211,7 +208,7 @@ where
where
R: GLWEToMut,
A: GLWEToRef,
D: GGSWCiphertextPreparedToRef<BE>,
D: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -239,8 +236,8 @@ where
let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(dsize));
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
@@ -268,7 +265,7 @@ where
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n().into(), cols, a_size);
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
@@ -324,7 +321,7 @@ where
}
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
Self: GetDegree
Self: ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes