This commit is contained in:
Pro7ech
2025-10-15 10:48:14 +02:00
parent a5df85170d
commit 008b800c01
74 changed files with 890 additions and 871 deletions

View File

@@ -1,13 +1,13 @@
use poulpy_hal::{
api::{ScratchAvailable, ScratchTakeBasic},
layouts::{Backend, Module, Scratch},
api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, Scratch},
};
use crate::{
dist::Distribution,
layouts::{
AutomorphismKey, Degree, GGLWE, GGLWEInfos, GGSW, GGSWInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret,
GLWESwitchingKey, GetDegree, Rank, TensorKey,
AutomorphismKey, GGLWE, GGLWEInfos, GGSW, GGSWInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret,
GLWESwitchingKey, Rank, TensorKey,
prepared::{
AutomorphismKeyPrepared, GGLWEPrepared, GGSWPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
GLWESwitchingKeyPrepared, TensorKeyPrepared,
@@ -17,12 +17,14 @@ use crate::{
pub trait ScratchTakeCore<B: Backend>
where
Self: ScratchTakeBasic<B> + ScratchAvailable,
Self: ScratchTakeBasic + ScratchAvailable,
{
fn take_glwe_ct<A>(&mut self, module: &Module<B>, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
fn take_glwe_ct<A, M>(&mut self, module: &M, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
where
A: GLWEInfos,
M: ModuleN,
{
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vec_znx(module, (infos.rank() + 1).into(), infos.size());
(
GLWE {
@@ -34,25 +36,28 @@ where
)
}
fn take_glwe_ct_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
fn take_glwe_ct_slice<A, M>(&mut self, module: &M, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
where
A: GLWEInfos,
M: ModuleN,
{
let mut scratch: &mut Scratch<B> = self;
let mut scratch: &mut Self = self;
let mut cts: Vec<GLWE<&mut [u8]>> = Vec::with_capacity(size);
for _ in 0..size {
let (ct, new_scratch) = scratch.take_glwe_ct(infos);
let (ct, new_scratch) = scratch.take_glwe_ct(module, infos);
scratch = new_scratch;
cts.push(ct);
}
(cts, scratch)
}
fn take_glwe_pt<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
fn take_glwe_pt<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
where
A: GLWEInfos,
M: ModuleN,
{
let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vec_znx(module, 1, infos.size());
(
GLWEPlaintext {
k: infos.k(),
@@ -63,12 +68,14 @@ where
)
}
fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
fn take_gglwe<A, M>(&mut self, module: &M, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN,
{
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_mat_znx(
infos.n().into(),
module,
infos.dnum().0.div_ceil(infos.dsize().0) as usize,
infos.rank_in().into(),
(infos.rank_out() + 1).into(),
@@ -85,12 +92,14 @@ where
)
}
fn take_gglwe_prepared<A>(&mut self, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
fn take_gglwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN + VmpPMatBytesOf,
{
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vmp_pmat(
infos.n().into(),
module,
infos.dnum().into(),
infos.rank_in().into(),
(infos.rank_out() + 1).into(),
@@ -107,12 +116,14 @@ where
)
}
fn take_ggsw<A>(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
fn take_ggsw<A, M>(&mut self, module: &M, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
where
A: GGSWInfos,
M: ModuleN,
{
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_mat_znx(
infos.n().into(),
module,
infos.dnum().into(),
(infos.rank() + 1).into(),
(infos.rank() + 1).into(),
@@ -129,12 +140,14 @@ where
)
}
fn take_ggsw_prepared<A>(&mut self, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
fn take_ggsw_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
where
A: GGSWInfos,
M: ModuleN + VmpPMatBytesOf,
{
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vmp_pmat(
infos.n().into(),
module,
infos.dnum().into(),
(infos.rank() + 1).into(),
(infos.rank() + 1).into(),
@@ -151,25 +164,33 @@ where
)
}
fn take_ggsw_prepared_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
fn take_ggsw_prepared_slice<A, M>(
&mut self,
module: &M,
size: usize,
infos: &A,
) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
where
A: GGSWInfos,
M: ModuleN + VmpPMatBytesOf,
{
let mut scratch: &mut Scratch<B> = self;
let mut scratch: &mut Self = self;
let mut cts: Vec<GGSWPrepared<&mut [u8], B>> = Vec::with_capacity(size);
for _ in 0..size {
let (ct, new_scratch) = scratch.take_ggsw_prepared(infos);
let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos);
scratch = new_scratch;
cts.push(ct)
}
(cts, scratch)
}
fn take_glwe_pk<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
fn take_glwe_pk<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
where
A: GLWEInfos,
M: ModuleN,
{
let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vec_znx(module, (infos.rank() + 1).into(), infos.size());
(
GLWEPublicKey {
k: infos.k(),
@@ -181,11 +202,13 @@ where
)
}
fn take_glwe_pk_prepared<A>(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
fn take_glwe_pk_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
where
A: GLWEInfos,
M: ModuleN + VecZnxDftBytesOf,
{
let (data, scratch) = self.take_vec_znx_dft(infos.n().into(), (infos.rank() + 1).into(), infos.size());
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size());
(
GLWEPublicKeyPrepared {
k: infos.k(),
@@ -197,8 +220,11 @@ where
)
}
fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
fn take_glwe_secret<M>(&mut self, module: &M, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (data, scratch) = self.take_scalar_znx(module, rank.into());
(
GLWESecret {
data,
@@ -208,8 +234,11 @@ where
)
}
fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) {
let (data, scratch) = self.take_svp_ppol(n.into(), rank.into());
fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
where
M: ModuleN + SvpPPolBytesOf,
{
let (data, scratch) = self.take_svp_ppol(module, rank.into());
(
GLWESecretPrepared {
data,
@@ -219,11 +248,13 @@ where
)
}
fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
fn take_glwe_switching_key<A, M>(&mut self, module: &M, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN,
{
let (data, scratch) = self.take_gglwe(infos);
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_gglwe(module, infos);
(
GLWESwitchingKey {
key: data,
@@ -234,11 +265,17 @@ where
)
}
fn take_gglwe_switching_key_prepared<A>(&mut self, infos: &A) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
fn take_gglwe_switching_key_prepared<A, M>(
&mut self,
module: &M,
infos: &A,
) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN + VmpPMatBytesOf,
{
let (data, scratch) = self.take_gglwe_prepared(infos);
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_gglwe_prepared(module, infos);
(
GLWESwitchingKeyPrepared {
key: data,
@@ -249,26 +286,36 @@ where
)
}
fn take_gglwe_automorphism_key<A>(&mut self, infos: &A) -> (AutomorphismKey<&mut [u8]>, &mut Self)
fn take_gglwe_automorphism_key<A, M>(&mut self, module: &M, infos: &A) -> (AutomorphismKey<&mut [u8]>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN,
{
let (data, scratch) = self.take_glwe_switching_key(infos);
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_glwe_switching_key(module, infos);
(AutomorphismKey { key: data, p: 0 }, scratch)
}
fn take_gglwe_automorphism_key_prepared<A>(&mut self, infos: &A) -> (AutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
fn take_gglwe_automorphism_key_prepared<A, M>(
&mut self,
module: &M,
infos: &A,
) -> (AutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN + VmpPMatBytesOf,
{
let (data, scratch) = self.take_gglwe_switching_key_prepared(infos);
assert_eq!(module.n() as u32, infos.n());
let (data, scratch) = self.take_gglwe_switching_key_prepared(module, infos);
(AutomorphismKeyPrepared { key: data, p: 0 }, scratch)
}
fn take_tensor_key<A>(&mut self, infos: &A) -> (TensorKey<&mut [u8]>, &mut Self)
fn take_tensor_key<A, M>(&mut self, module: &M, infos: &A) -> (TensorKey<&mut [u8]>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN,
{
assert_eq!(module.n() as u32, infos.n());
assert_eq!(
infos.rank_in(),
infos.rank_out(),
@@ -277,28 +324,30 @@ where
let mut keys: Vec<GLWESwitchingKey<&mut [u8]>> = Vec::new();
let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
let mut scratch: &mut Scratch<B> = self;
let mut scratch: &mut Self = self;
let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.gglwe_layout();
ksk_infos.rank_in = Rank(1);
if pairs != 0 {
let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos);
let (gglwe, s) = scratch.take_glwe_switching_key(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
for _ in 1..pairs {
let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos);
let (gglwe, s) = scratch.take_glwe_switching_key(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
(TensorKey { keys }, scratch)
}
fn take_gglwe_tensor_key_prepared<A>(&mut self, infos: &A) -> (TensorKeyPrepared<&mut [u8], B>, &mut Self)
fn take_gglwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (TensorKeyPrepared<&mut [u8], B>, &mut Self)
where
A: GGLWEInfos,
M: ModuleN + VmpPMatBytesOf,
{
assert_eq!(module.n() as u32, infos.n());
assert_eq!(
infos.rank_in(),
infos.rank_out(),
@@ -308,18 +357,18 @@ where
let mut keys: Vec<GLWESwitchingKeyPrepared<&mut [u8], B>> = Vec::new();
let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
let mut scratch: &mut Scratch<B> = self;
let mut scratch: &mut Self = self;
let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.gglwe_layout();
ksk_infos.rank_in = Rank(1);
if pairs != 0 {
let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos);
let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
for _ in 1..pairs {
let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos);
let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
@@ -327,4 +376,4 @@ where
}
}
impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic<B> + ScratchAvailable {}
impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable {}