This commit is contained in:
Pro7ech
2025-10-15 11:11:57 +02:00
parent 008b800c01
commit c604676f2e
13 changed files with 191 additions and 260 deletions

View File

@@ -1,82 +1,46 @@
use poulpy_hal::{ use poulpy_hal::layouts::{Backend, DataMut, Scratch};
api::{
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, use crate::{
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, ScratchTakeCore,
}, external_product::gglwe_ksk::GGLWEExternalProduct,
layouts::{Backend, DataMut, DataRef, Module, Scratch}, layouts::{AutomorphismKey, AutomorphismKeyToRef, GGLWEInfos, GGSWInfos, prepared::GGSWPreparedToRef},
}; };
use crate::layouts::{AutomorphismKey, GGLWEInfos, GGSWInfos, GLWESwitchingKey, prepared::GGSWPrepared};
impl AutomorphismKey<Vec<u8>> { impl AutomorphismKey<Vec<u8>> {
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>( pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(
module: &Module<B>, &self,
out_infos: &OUT, module: &M,
in_infos: &IN, res_infos: &R,
ggsw_infos: &GGSW, a_infos: &A,
b_infos: &B,
) -> usize ) -> usize
where where
OUT: GGLWEInfos, R: GGLWEInfos,
IN: GGLWEInfos, A: GGLWEInfos,
GGSW: GGSWInfos, B: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, M: GGLWEExternalProduct<BE>,
{ {
GLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos) module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos)
} }
} }
impl<DataSelf: DataMut> AutomorphismKey<DataSelf> { impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>( pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, M: GGLWEExternalProduct<BE>,
lhs: &AutomorphismKey<DataLhs>, A: AutomorphismKeyToRef,
rhs: &GGSWPrepared<DataRhs, B>, B: GGSWPreparedToRef<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeCore<BE>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{ {
self.key.external_product(module, &lhs.key, rhs, scratch); module.gglwe_external_product(&mut self.key.key, &a.to_ref().key.key, b, scratch);
} }
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>( pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, M: GGLWEExternalProduct<BE>,
rhs: &GGSWPrepared<DataRhs, B>, A: GGSWPreparedToRef<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeCore<BE>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{ {
self.key.external_product_inplace(module, rhs, scratch); module.gglwe_external_product_inplace(&mut self.key.key, a, scratch);
} }
} }

View File

@@ -1,143 +1,134 @@
use poulpy_hal::{ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
api::{
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, use crate::{
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, GLWEExternalProduct, ScratchTakeCore,
layouts::{
GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GGSWInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyToRef,
prepared::{GGSWPrepared, GGSWPreparedToRef},
}, },
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
}; };
use crate::layouts::{GGLWEInfos, GGSWInfos, GLWE, GLWESwitchingKey, prepared::GGSWPrepared}; pub trait GGLWEExternalProduct<BE: Backend>
impl GLWESwitchingKey<Vec<u8>> {
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
) -> usize
where where
OUT: GGLWEInfos, Self: GLWEExternalProduct<BE>,
IN: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{ {
GLWE::external_product_scratch_space( fn gglwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
module, where
&out_infos.glwe_layout(), R: GGLWEInfos,
&in_infos.glwe_layout(), A: GGLWEInfos,
ggsw_infos, B: GGSWInfos,
) {
self.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
} }
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>( fn gglwe_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
module: &Module<B>, where
out_infos: &OUT, R: GGLWEToMut,
ggsw_infos: &GGSW, A: GGLWEToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
assert_eq!(
res.rank_in(),
a.rank_in(),
"res input rank_in: {} != a input rank_in: {}",
res.rank_in(),
a.rank_in()
);
assert_eq!(
a.rank_out(),
b.rank(),
"a output rank_out: {} != b rank: {}",
a.rank_out(),
b.rank()
);
assert_eq!(
res.rank_out(),
b.rank(),
"res output rank_out: {} != b rank: {}",
res.rank_out(),
b.rank()
);
for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() {
self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch);
}
}
for row in res.dnum().min(a.dnum()).into()..res.dnum().into() {
for col in 0..res.rank_in().into() {
res.at_mut(row, col).data_mut().zero();
}
}
}
fn gglwe_external_product_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGSWPrepared<&[u8], BE> = &a.to_ref();
assert_eq!(
res.rank_out(),
a.rank(),
"res output rank: {} != a rank: {}",
res.rank_out(),
a.rank()
);
for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() {
self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch);
}
}
}
}
impl<BE: Backend> GGLWEExternalProduct<BE> for Module<BE> where Self: GLWEExternalProduct<BE> {}
impl GLWESwitchingKey<Vec<u8>> {
pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(
&self,
module: &M,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize ) -> usize
where where
OUT: GGLWEInfos, R: GGLWEInfos,
GGSW: GGSWInfos, A: GGLWEInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, B: GGSWInfos,
M: GGLWEExternalProduct<BE>,
{ {
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos) module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
} }
} }
impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> { impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>( pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, M: GGLWEExternalProduct<BE>,
lhs: &GLWESwitchingKey<DataLhs>, A: GLWESwitchingKeyToRef,
rhs: &GGSWPrepared<DataRhs, B>, B: GGSWPreparedToRef<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeCore<BE>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{ {
#[cfg(debug_assertions)] module.gglwe_external_product(&mut self.key, &a.to_ref().key, b, scratch);
}
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
M: GGLWEExternalProduct<BE>,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
use crate::layouts::GLWEInfos; module.gglwe_external_product_inplace(&mut self.key, a, scratch);
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank(),
"ksk_in output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
});
});
(self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_j| {
self.at_mut(row_i, col_j).data.zero();
});
});
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
{
#[cfg(debug_assertions)]
{
use crate::layouts::GLWEInfos;
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product_inplace(module, rhs, scratch);
});
});
} }
} }

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{ use poulpy_hal::{
api::ScratchAvailable, api::ScratchAvailable,
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
}; };
use crate::{ use crate::{
@@ -115,20 +115,22 @@ impl GGSW<Vec<u8>> {
} }
impl<DataSelf: DataMut> GGSW<DataSelf> { impl<DataSelf: DataMut> GGSW<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>( pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, M: GGSWExternalProduct<BE>,
lhs: &GGSW<DataLhs>, A: GGSWToRef,
rhs: &GGSWPrepared<DataRhs, B>, B: GGSWPreparedToRef<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeCore<BE>,
) { {
module.ggsw_external_product(self, a, b, scratch);
} }
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>( pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
&mut self, where
module: &Module<B>, M: GGSWExternalProduct<BE>,
rhs: &GGSWPrepared<DataRhs, B>, A: GGSWPreparedToRef<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeCore<BE>,
) { {
module.ggsw_external_product_inplace(self, a, scratch);
} }
} }

View File

@@ -15,35 +15,30 @@ use crate::{
}; };
impl<DataSelf: DataMut> GLWE<DataSelf> { impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product_scratch_space<R, A, B, BE: Backend>( pub fn external_product_scratch_space<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
module: Module<BE>,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
where where
R: GLWEInfos, R: GLWEInfos,
A: GLWEInfos, A: GLWEInfos,
B: GGSWInfos, B: GGSWInfos,
Module<BE>: GLWEExternalProduct<BE>, M: GLWEExternalProduct<BE>,
{ {
module.glwe_external_product_scratch_space(res_infos, a_infos, b_infos) module.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
} }
pub fn external_product<A, B, BE: Backend>(&mut self, module: &Module<BE>, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
A: GLWEToRef, A: GLWEToRef,
B: GGSWPreparedToRef<BE>, B: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>, M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.glwe_external_product(self, a, b, scratch); module.glwe_external_product(self, a, b, scratch);
} }
pub fn external_product_inplace<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A, scratch: &mut Scratch<BE>) pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where where
A: GGSWPreparedToRef<BE>, A: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>, M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.glwe_external_product_inplace(self, a, scratch); module.glwe_external_product_inplace(self, a, scratch);

View File

@@ -161,14 +161,7 @@ impl AutomorphismKeyCompressed<Vec<u8>> {
module.bytes_of_automorphism_key_compressed_from_infos(infos) module.bytes_of_automorphism_key_compressed_from_infos(infos)
} }
pub fn bytes_of<M>( pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where where
M: AutomorphismKeyCompressedAlloc, M: AutomorphismKeyCompressedAlloc,
{ {

View File

@@ -217,14 +217,7 @@ impl GGLWECompressed<Vec<u8>> {
module.bytes_of_gglwe_compressed_from_infos(infos) module.bytes_of_gglwe_compressed_from_infos(infos)
} }
pub fn byte_of<M>( pub fn byte_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where where
M: GGLWECompressedAlloc, M: GGLWECompressedAlloc,
{ {

View File

@@ -166,14 +166,7 @@ impl GLWESwitchingKeyCompressed<Vec<u8>> {
module.bytes_of_glwe_switching_key_compressed_from_infos(infos) module.bytes_of_glwe_switching_key_compressed_from_infos(infos)
} }
pub fn bytes_of<M>( pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where where
M: GLWESwitchingKeyCompressedAlloc, M: GLWESwitchingKeyCompressedAlloc,
{ {

View File

@@ -163,14 +163,7 @@ impl TensorKeyCompressed<Vec<u8>> {
module.bytes_of_tensor_key_compressed_from_infos(infos) module.bytes_of_tensor_key_compressed_from_infos(infos)
} }
pub fn bytes_of<M>( pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where where
M: TensorKeyCompressedAlloc, M: TensorKeyCompressedAlloc,
{ {

View File

@@ -193,14 +193,7 @@ impl GGSWCompressed<Vec<u8>> {
module.bytes_of_ggsw_compressed_key_from_infos(infos) module.bytes_of_ggsw_compressed_key_from_infos(infos)
} }
pub fn bytes_of<M>( pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where where
M: GGSWCompressedAlloc, M: GGSWCompressedAlloc,
{ {

View File

@@ -304,7 +304,15 @@ impl GGLWE<Vec<u8>> {
module.alloc_glwe_from_infos(infos) module.alloc_glwe_from_infos(infos)
} }
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self pub fn alloc<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
rank_out: Rank,
dnum: Dnum,
dsize: Dsize,
) -> Self
where where
M: GGLWEAlloc, M: GGLWEAlloc,
{ {

View File

@@ -123,7 +123,10 @@ impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B>{
module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos)
} }
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self where M: LWEToGLWESwitchingKeyPreparedAlloc<B>, { pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self
where
M: LWEToGLWESwitchingKeyPreparedAlloc<B>,
{
module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum)
} }
@@ -135,7 +138,10 @@ impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B>{
module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos)
} }
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize where M: LWEToGLWESwitchingKeyPreparedAlloc<B>,{ pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize
where
M: LWEToGLWESwitchingKeyPreparedAlloc<B>,
{
module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum)
} }
} }