gglwe compressed encrypt

This commit is contained in:
Pro7ech
2025-10-17 10:59:35 +02:00
parent e0d3ca5cea
commit 69d04c71bc
2 changed files with 96 additions and 86 deletions

View File

@@ -1,50 +1,56 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
ZnNormalizeInplace,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero},
source::Source, source::Source,
}; };
use crate::{ use crate::{
ScratchTakeCore, ScratchTakeCore,
encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, encryption::{
SIGMA,
glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal},
},
layouts::{ layouts::{
GGLWE, GGLWEInfos, LWEInfos, GGLWEInfos, GLWEPlaintextAlloc, LWEInfos,
compressed::{GGLWECompressed, GGLWECompressedToMut}, compressed::{GGLWECompressed, GGLWECompressedToMut},
prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, prepared::GLWESecretPreparedToRef,
}, },
}; };
impl<D: DataMut> GGLWECompressed<D> { impl<D: DataMut> GGLWECompressed<D> {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>( pub fn encrypt_sk<M, P, S, BE: Backend>(
&mut self, &mut self,
module: &Module<B>, module: &M,
pt: &ScalarZnx<DataPt>, pt: &P,
sk: &GLWESecretPrepared<DataSk, B>, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
Module<B>: GGLWECompressedEncryptSk<B>, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<BE>,
M: GGLWECompressedEncryptSk<BE>,
{ {
module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch); module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch);
} }
} }
impl GGLWECompressed<Vec<u8>> { impl GGLWECompressed<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize pub fn encrypt_sk_tmp_bytes<M, BE: Backend, A>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, M: GGLWECompressedEncryptSk<BE>,
{ {
GGLWE::encrypt_sk_tmp_bytes(module, infos) module.gglwe_compressed_encrypt_sk_tmp_bytes(infos)
} }
} }
pub trait GGLWECompressedEncryptSk<B: Backend> { pub trait GGLWECompressedEncryptSk<BE: Backend> {
fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos;
fn gglwe_compressed_encrypt_sk<R, P, S>( fn gglwe_compressed_encrypt_sk<R, P, S>(
&self, &self,
res: &mut R, res: &mut R,
@@ -52,24 +58,33 @@ pub trait GGLWECompressedEncryptSk<B: Backend> {
sk: &S, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGLWECompressedToMut, R: GGLWECompressedToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>; S: GLWESecretPreparedToRef<BE>;
} }
impl<B: Backend> GGLWECompressedEncryptSk<B> for Module<B> impl<BE: Backend> GGLWECompressedEncryptSk<BE> for Module<BE>
where where
Module<B>: ModuleN Module<BE>: ModuleN
+ GLWEEncryptSkInternal<B> + GLWEEncryptSkInternal<BE>
+ VecZnxNormalizeInplace<B> + GLWEEncryptSk<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxDftBytesOf + VecZnxDftBytesOf
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddScalarInplace + VecZnxAddScalarInplace
+ ZnNormalizeInplace<B>, + VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable + ScratchTakeCore<B>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
self.glwe_encrypt_sk_tmp_bytes(infos)
.max(self.vec_znx_normalize_tmp_bytes())
+ self.bytes_of_glwe_plaintext_from_infos(infos)
}
fn gglwe_compressed_encrypt_sk<R, P, S>( fn gglwe_compressed_encrypt_sk<R, P, S>(
&self, &self,
res: &mut R, res: &mut R,
@@ -77,18 +92,15 @@ where
sk: &S, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGLWECompressedToMut, R: GGLWECompressedToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>, S: GLWESecretPreparedToRef<BE>,
{ {
let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut(); let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut();
let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
#[cfg(debug_assertions)]
{
use poulpy_hal::layouts::ZnxInfos;
let sk = &sk.to_ref(); let sk = &sk.to_ref();
assert_eq!( assert_eq!(
@@ -122,7 +134,6 @@ where
res.dnum().0 * res.dsize().0 * res.base2k().0, res.dnum().0 * res.dsize().0 * res.base2k().0,
res.k() res.k()
); );
}
let dnum: usize = res.dnum().into(); let dnum: usize = res.dnum().into();
let dsize: usize = res.dsize().into(); let dsize: usize = res.dsize().into();

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace}, api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, ZnxInfos, ZnxZero}, layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
source::Source, source::Source,
}; };
@@ -8,7 +8,7 @@ use crate::{
SIGMA, ScratchTakeCore, SIGMA, ScratchTakeCore,
encryption::glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal}, encryption::glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal},
layouts::{ layouts::{
GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, GLWEPlaintextAlloc, LWEInfos,
prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
}, },
}; };
@@ -43,7 +43,7 @@ impl<D: DataMut> GGSW<D> {
} }
} }
pub trait GGSWEncryptSk<B: Backend> { pub trait GGSWEncryptSk<BE: Backend> {
fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGSWInfos; A: GGSWInfos;
@@ -55,32 +55,31 @@ pub trait GGSWEncryptSk<B: Backend> {
sk: &S, sk: &S,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGSWToMut, R: GGSWToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>; S: GLWESecretPreparedToRef<BE>;
} }
impl<B: Backend> GGSWEncryptSk<B> for Module<B> impl<BE: Backend> GGSWEncryptSk<BE> for Module<BE>
where where
Module<B>: ModuleN Module<BE>: ModuleN
+ GLWEEncryptSkInternal<B> + GLWEEncryptSkInternal<BE>
+ GLWEEncryptSk<B> + GLWEEncryptSk<BE>
+ VecZnxDftBytesOf + VecZnxDftBytesOf
+ VecZnxNormalizeInplace<B> + VecZnxNormalizeInplace<BE>
+ VecZnxAddScalarInplace, + VecZnxAddScalarInplace
Scratch<B>: ScratchTakeCore<B>, + VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGSWInfos, A: GGSWInfos,
{ {
let size = infos.size();
self.glwe_encrypt_sk_tmp_bytes(infos) self.glwe_encrypt_sk_tmp_bytes(infos)
+ VecZnx::bytes_of(self.n(), (infos.rank() + 1).into(), size) .max(self.vec_znx_normalize_tmp_bytes())
+ VecZnx::bytes_of(self.n(), 1, size) + self.bytes_of_glwe_plaintext_from_infos(infos)
+ self.bytes_of_vec_znx_dft((infos.rank() + 1).into(), size)
} }
fn ggsw_encrypt_sk<R, P, S>( fn ggsw_encrypt_sk<R, P, S>(
@@ -90,15 +89,15 @@ where
sk: &S, sk: &S,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGSWToMut, R: GGSWToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>, S: GLWESecretPreparedToRef<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
assert_eq!(res.rank(), sk.rank()); assert_eq!(res.rank(), sk.rank());
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
@@ -111,7 +110,7 @@ where
let dsize: usize = res.dsize().into(); let dsize: usize = res.dsize().into();
let cols: usize = (rank + 1).into(); let cols: usize = (rank + 1).into();
let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, &res.glwe_layout()); let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, res);
for row_i in 0..res.dnum().into() { for row_i in 0..res.dnum().into() {
tmp_pt.data.zero(); tmp_pt.data.zero();