gglwe compressed encrypt

This commit is contained in:
Pro7ech
2025-10-17 10:59:35 +02:00
parent e0d3ca5cea
commit 69d04c71bc
2 changed files with 96 additions and 86 deletions

View File

@@ -1,50 +1,56 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
ZnNormalizeInplace,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero},
source::Source, source::Source,
}; };
use crate::{ use crate::{
ScratchTakeCore, ScratchTakeCore,
encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, encryption::{
SIGMA,
glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal},
},
layouts::{ layouts::{
GGLWE, GGLWEInfos, LWEInfos, GGLWEInfos, GLWEPlaintextAlloc, LWEInfos,
compressed::{GGLWECompressed, GGLWECompressedToMut}, compressed::{GGLWECompressed, GGLWECompressedToMut},
prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, prepared::GLWESecretPreparedToRef,
}, },
}; };
impl<D: DataMut> GGLWECompressed<D> { impl<D: DataMut> GGLWECompressed<D> {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>( pub fn encrypt_sk<M, P, S, BE: Backend>(
&mut self, &mut self,
module: &Module<B>, module: &M,
pt: &ScalarZnx<DataPt>, pt: &P,
sk: &GLWESecretPrepared<DataSk, B>, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
Module<B>: GGLWECompressedEncryptSk<B>, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<BE>,
M: GGLWECompressedEncryptSk<BE>,
{ {
module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch); module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch);
} }
} }
impl GGLWECompressed<Vec<u8>> { impl GGLWECompressed<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize pub fn encrypt_sk_tmp_bytes<M, BE: Backend, A>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, M: GGLWECompressedEncryptSk<BE>,
{ {
GGLWE::encrypt_sk_tmp_bytes(module, infos) module.gglwe_compressed_encrypt_sk_tmp_bytes(infos)
} }
} }
pub trait GGLWECompressedEncryptSk<B: Backend> { pub trait GGLWECompressedEncryptSk<BE: Backend> {
fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos;
fn gglwe_compressed_encrypt_sk<R, P, S>( fn gglwe_compressed_encrypt_sk<R, P, S>(
&self, &self,
res: &mut R, res: &mut R,
@@ -52,24 +58,33 @@ pub trait GGLWECompressedEncryptSk<B: Backend> {
sk: &S, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGLWECompressedToMut, R: GGLWECompressedToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>; S: GLWESecretPreparedToRef<BE>;
} }
impl<B: Backend> GGLWECompressedEncryptSk<B> for Module<B> impl<BE: Backend> GGLWECompressedEncryptSk<BE> for Module<BE>
where where
Module<B>: ModuleN Module<BE>: ModuleN
+ GLWEEncryptSkInternal<B> + GLWEEncryptSkInternal<BE>
+ VecZnxNormalizeInplace<B> + GLWEEncryptSk<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxDftBytesOf + VecZnxDftBytesOf
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddScalarInplace + VecZnxAddScalarInplace
+ ZnNormalizeInplace<B>, + VecZnxNormalizeTmpBytes,
Scratch<B>: ScratchAvailable + ScratchTakeCore<B>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
self.glwe_encrypt_sk_tmp_bytes(infos)
.max(self.vec_znx_normalize_tmp_bytes())
+ self.bytes_of_glwe_plaintext_from_infos(infos)
}
fn gglwe_compressed_encrypt_sk<R, P, S>( fn gglwe_compressed_encrypt_sk<R, P, S>(
&self, &self,
res: &mut R, res: &mut R,
@@ -77,52 +92,48 @@ where
sk: &S, sk: &S,
seed: [u8; 32], seed: [u8; 32],
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGLWECompressedToMut, R: GGLWECompressedToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>, S: GLWESecretPreparedToRef<BE>,
{ {
let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut(); let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut();
let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
#[cfg(debug_assertions)] let sk = &sk.to_ref();
{
use poulpy_hal::layouts::ZnxInfos;
let sk = &sk.to_ref();
assert_eq!( assert_eq!(
res.rank_in(), res.rank_in(),
pt.cols() as u32, pt.cols() as u32,
"res.rank_in(): {} != pt.cols(): {}", "res.rank_in(): {} != pt.cols(): {}",
res.rank_in(), res.rank_in(),
pt.cols() pt.cols()
); );
assert_eq!( assert_eq!(
res.rank_out(), res.rank_out(),
sk.rank(), sk.rank(),
"res.rank_out(): {} != sk.rank(): {}", "res.rank_out(): {} != sk.rank(): {}",
res.rank_out(), res.rank_out(),
sk.rank() sk.rank()
); );
assert_eq!(res.n(), sk.n()); assert_eq!(res.n(), sk.n());
assert_eq!(pt.n() as u32, sk.n()); assert_eq!(pt.n() as u32, sk.n());
assert!( assert!(
scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res), scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res),
"scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}", "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}",
scratch.available(), scratch.available(),
GGLWECompressed::encrypt_sk_tmp_bytes(self, res) GGLWECompressed::encrypt_sk_tmp_bytes(self, res)
); );
assert!( assert!(
res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0,
"res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}",
res.dnum(), res.dnum(),
res.dsize(), res.dsize(),
res.base2k(), res.base2k(),
res.dnum().0 * res.dsize().0 * res.base2k().0, res.dnum().0 * res.dsize().0 * res.base2k().0,
res.k() res.k()
); );
}
let dnum: usize = res.dnum().into(); let dnum: usize = res.dnum().into();
let dsize: usize = res.dsize().into(); let dsize: usize = res.dsize().into();

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace}, api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, ZnxInfos, ZnxZero}, layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
source::Source, source::Source,
}; };
@@ -8,7 +8,7 @@ use crate::{
SIGMA, ScratchTakeCore, SIGMA, ScratchTakeCore,
encryption::glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal}, encryption::glwe_ct::{GLWEEncryptSk, GLWEEncryptSkInternal},
layouts::{ layouts::{
GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, GLWEPlaintextAlloc, LWEInfos,
prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
}, },
}; };
@@ -43,7 +43,7 @@ impl<D: DataMut> GGSW<D> {
} }
} }
pub trait GGSWEncryptSk<B: Backend> { pub trait GGSWEncryptSk<BE: Backend> {
fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGSWInfos; A: GGSWInfos;
@@ -55,32 +55,31 @@ pub trait GGSWEncryptSk<B: Backend> {
sk: &S, sk: &S,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGSWToMut, R: GGSWToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>; S: GLWESecretPreparedToRef<BE>;
} }
impl<B: Backend> GGSWEncryptSk<B> for Module<B> impl<BE: Backend> GGSWEncryptSk<BE> for Module<BE>
where where
Module<B>: ModuleN Module<BE>: ModuleN
+ GLWEEncryptSkInternal<B> + GLWEEncryptSkInternal<BE>
+ GLWEEncryptSk<B> + GLWEEncryptSk<BE>
+ VecZnxDftBytesOf + VecZnxDftBytesOf
+ VecZnxNormalizeInplace<B> + VecZnxNormalizeInplace<BE>
+ VecZnxAddScalarInplace, + VecZnxAddScalarInplace
Scratch<B>: ScratchTakeCore<B>, + VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGSWInfos, A: GGSWInfos,
{ {
let size = infos.size();
self.glwe_encrypt_sk_tmp_bytes(infos) self.glwe_encrypt_sk_tmp_bytes(infos)
+ VecZnx::bytes_of(self.n(), (infos.rank() + 1).into(), size) .max(self.vec_znx_normalize_tmp_bytes())
+ VecZnx::bytes_of(self.n(), 1, size) + self.bytes_of_glwe_plaintext_from_infos(infos)
+ self.bytes_of_vec_znx_dft((infos.rank() + 1).into(), size)
} }
fn ggsw_encrypt_sk<R, P, S>( fn ggsw_encrypt_sk<R, P, S>(
@@ -90,15 +89,15 @@ where
sk: &S, sk: &S,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<B>, scratch: &mut Scratch<BE>,
) where ) where
R: GGSWToMut, R: GGSWToMut,
P: ScalarZnxToRef, P: ScalarZnxToRef,
S: GLWESecretPreparedToRef<B>, S: GLWESecretPreparedToRef<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
assert_eq!(res.rank(), sk.rank()); assert_eq!(res.rank(), sk.rank());
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
@@ -111,7 +110,7 @@ where
let dsize: usize = res.dsize().into(); let dsize: usize = res.dsize().into();
let cols: usize = (rank + 1).into(); let cols: usize = (rank + 1).into();
let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, &res.glwe_layout()); let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, res);
for row_i in 0..res.dnum().into() { for row_i in 0..res.dnum().into() {
tmp_pt.data.zero(); tmp_pt.data.zero();