This commit is contained in:
Pro7ech
2025-10-15 11:11:57 +02:00
parent 008b800c01
commit c604676f2e
13 changed files with 191 additions and 260 deletions

View File

@@ -1,82 +1,46 @@
use poulpy_hal::{
api::{
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
use poulpy_hal::layouts::{Backend, DataMut, Scratch};
use crate::{
ScratchTakeCore,
external_product::gglwe_ksk::GGLWEExternalProduct,
layouts::{AutomorphismKey, AutomorphismKeyToRef, GGLWEInfos, GGSWInfos, prepared::GGSWPreparedToRef},
};
use crate::layouts::{AutomorphismKey, GGLWEInfos, GGSWInfos, GLWESwitchingKey, prepared::GGSWPrepared};
impl AutomorphismKey<Vec<u8>> {
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(
&self,
module: &M,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
where
OUT: GGLWEInfos,
IN: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
R: GGLWEInfos,
A: GGLWEInfos,
B: GGSWInfos,
M: GGLWEExternalProduct<BE>,
{
GLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos)
module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
}
}
impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &AutomorphismKey<DataLhs>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
M: GGLWEExternalProduct<BE>,
A: AutomorphismKeyToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.key.external_product(module, &lhs.key, rhs, scratch);
module.gglwe_external_product(&mut self.key.key, &a.to_ref().key.key, b, scratch);
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
M: GGLWEExternalProduct<BE>,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.key.external_product_inplace(module, rhs, scratch);
module.gglwe_external_product_inplace(&mut self.key.key, a, scratch);
}
}

View File

@@ -1,143 +1,134 @@
use poulpy_hal::{
api::{
ScratchAvailable, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
use crate::{
GLWEExternalProduct, ScratchTakeCore,
layouts::{
GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GGSWInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyToRef,
prepared::{GGSWPrepared, GGSWPreparedToRef},
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
use crate::layouts::{GGLWEInfos, GGSWInfos, GLWE, GLWESwitchingKey, prepared::GGSWPrepared};
impl GLWESwitchingKey<Vec<u8>> {
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
) -> usize
pub trait GGLWEExternalProduct<BE: Backend>
where
Self: GLWEExternalProduct<BE>,
{
fn gglwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
OUT: GGLWEInfos,
IN: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
R: GGLWEInfos,
A: GGLWEInfos,
B: GGSWInfos,
{
GLWE::external_product_scratch_space(
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
ggsw_infos,
)
self.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
ggsw_infos: &GGSW,
fn gglwe_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut,
A: GGLWEToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
assert_eq!(
res.rank_in(),
a.rank_in(),
"res input rank_in: {} != a input rank_in: {}",
res.rank_in(),
a.rank_in()
);
assert_eq!(
a.rank_out(),
b.rank(),
"a output rank_out: {} != b rank: {}",
a.rank_out(),
b.rank()
);
assert_eq!(
res.rank_out(),
b.rank(),
"res output rank_out: {} != b rank: {}",
res.rank_out(),
b.rank()
);
for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() {
self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch);
}
}
for row in res.dnum().min(a.dnum()).into()..res.dnum().into() {
for col in 0..res.rank_in().into() {
res.at_mut(row, col).data_mut().zero();
}
}
}
fn gglwe_external_product_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGSWPrepared<&[u8], BE> = &a.to_ref();
assert_eq!(
res.rank_out(),
a.rank(),
"res output rank: {} != a rank: {}",
res.rank_out(),
a.rank()
);
for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() {
self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch);
}
}
}
}
impl<BE: Backend> GGLWEExternalProduct<BE> for Module<BE> where Self: GLWEExternalProduct<BE> {}
impl GLWESwitchingKey<Vec<u8>> {
pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(
&self,
module: &M,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
where
OUT: GGLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
R: GGLWEInfos,
A: GGLWEInfos,
B: GGSWInfos,
M: GGLWEExternalProduct<BE>,
{
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
}
}
impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GLWESwitchingKey<DataLhs>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
M: GGLWEExternalProduct<BE>,
A: GLWESwitchingKeyToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
#[cfg(debug_assertions)]
{
use crate::layouts::GLWEInfos;
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank(),
"ksk_in output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
module.gglwe_external_product(&mut self.key, &a.to_ref().key, b, scratch);
}
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
});
});
(self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_j| {
self.at_mut(row_i, col_j).data.zero();
});
});
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: ScratchAvailable,
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
M: GGLWEExternalProduct<BE>,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
#[cfg(debug_assertions)]
{
use crate::layouts::GLWEInfos;
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.dnum().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product_inplace(module, rhs, scratch);
});
});
module.gglwe_external_product_inplace(&mut self.key, a, scratch);
}
}

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{
api::ScratchAvailable,
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
};
use crate::{
@@ -115,20 +115,22 @@ impl GGSW<Vec<u8>> {
}
impl<DataSelf: DataMut> GGSW<DataSelf> {
pub fn external_product<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GGSW<DataLhs>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) {
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
M: GGSWExternalProduct<BE>,
A: GGSWToRef,
B: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.ggsw_external_product(self, a, b, scratch);
}
pub fn external_product_inplace<DataRhs: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
rhs: &GGSWPrepared<DataRhs, B>,
scratch: &mut Scratch<B>,
) {
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
M: GGSWExternalProduct<BE>,
A: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.ggsw_external_product_inplace(self, a, scratch);
}
}

View File

@@ -15,35 +15,30 @@ use crate::{
};
impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product_scratch_space<R, A, B, BE: Backend>(
module: Module<BE>,
res_infos: &R,
a_infos: &A,
b_infos: &B,
) -> usize
pub fn external_product_scratch_space<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos,
Module<BE>: GLWEExternalProduct<BE>,
M: GLWEExternalProduct<BE>,
{
module.glwe_external_product_scratch_space(res_infos, a_infos, b_infos)
}
pub fn external_product<A, B, BE: Backend>(&mut self, module: &Module<BE>, a: &A, b: &B, scratch: &mut Scratch<BE>)
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: GLWEToRef,
B: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>,
M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.glwe_external_product(self, a, b, scratch);
}
pub fn external_product_inplace<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A, scratch: &mut Scratch<BE>)
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GGSWPreparedToRef<BE>,
Module<BE>: GLWEExternalProduct<BE>,
M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.glwe_external_product_inplace(self, a, scratch);

View File

@@ -161,14 +161,7 @@ impl AutomorphismKeyCompressed<Vec<u8>> {
module.bytes_of_automorphism_key_compressed_from_infos(infos)
}
pub fn bytes_of<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
where
M: AutomorphismKeyCompressedAlloc,
{

View File

@@ -217,14 +217,7 @@ impl GGLWECompressed<Vec<u8>> {
module.bytes_of_gglwe_compressed_from_infos(infos)
}
pub fn byte_of<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
pub fn byte_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize
where
M: GGLWECompressedAlloc,
{

View File

@@ -166,14 +166,7 @@ impl GLWESwitchingKeyCompressed<Vec<u8>> {
module.bytes_of_glwe_switching_key_compressed_from_infos(infos)
}
pub fn bytes_of<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize
where
M: GLWESwitchingKeyCompressedAlloc,
{

View File

@@ -163,14 +163,7 @@ impl TensorKeyCompressed<Vec<u8>> {
module.bytes_of_tensor_key_compressed_from_infos(infos)
}
pub fn bytes_of<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
where
M: TensorKeyCompressedAlloc,
{

View File

@@ -193,14 +193,7 @@ impl GGSWCompressed<Vec<u8>> {
module.bytes_of_ggsw_compressed_key_from_infos(infos)
}
pub fn bytes_of<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
where
M: GGSWCompressedAlloc,
{

View File

@@ -91,7 +91,7 @@ pub trait LWECompressedAlloc {
}
}
impl<B: Backend> LWECompressedAlloc for Module<B>{}
impl<B: Backend> LWECompressedAlloc for Module<B> {}
impl LWECompressed<Vec<u8>> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self

View File

@@ -147,7 +147,7 @@ where
}
}
impl<B: Backend> LWESwitchingKeyCompressedAlloc for Module<B> where Self: GLWESwitchingKeyCompressedAlloc{}
impl<B: Backend> LWESwitchingKeyCompressedAlloc for Module<B> where Self: GLWESwitchingKeyCompressedAlloc {}
impl LWESwitchingKeyCompressed<Vec<u8>> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self

View File

@@ -304,7 +304,15 @@ impl GGLWE<Vec<u8>> {
module.alloc_glwe_from_infos(infos)
}
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self
pub fn alloc<M>(
module: &M,
base2k: Base2K,
k: TorusPrecision,
rank_in: Rank,
rank_out: Rank,
dnum: Dnum,
dsize: Dsize,
) -> Self
where
M: GGLWEAlloc,
{

View File

@@ -114,7 +114,7 @@ where
impl<B: Backend> LWEToGLWESwitchingKeyPreparedAlloc<B> for Module<B> where Self: GLWESwitchingKeyPreparedAlloc<B> {}
impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B>{
impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
where
A: GGLWEInfos,
@@ -123,7 +123,10 @@ impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B>{
module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos)
}
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self where M: LWEToGLWESwitchingKeyPreparedAlloc<B>, {
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self
where
M: LWEToGLWESwitchingKeyPreparedAlloc<B>,
{
module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum)
}
@@ -135,7 +138,10 @@ impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B>{
module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos)
}
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize where M: LWEToGLWESwitchingKeyPreparedAlloc<B>,{
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize
where
M: LWEToGLWESwitchingKeyPreparedAlloc<B>,
{
module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum)
}
}