From d28ccb4c8f0e94118d4cf1ece9c10ee9a6f8ccb9 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Mon, 13 Oct 2025 12:55:06 +0200 Subject: [PATCH] wip --- poulpy-core/src/layouts/prepared/glwe_pk.rs | 122 ++++++++++++++++---- poulpy-core/src/layouts/prepared/glwe_sk.rs | 86 +++++++++++--- 2 files changed, 169 insertions(+), 39 deletions(-) diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 0cbbdd0..f730cf8 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ use crate::{ dist::Distribution, - layouts::{Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision}, + layouts::{Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, GLWEPublicKeyToRef, LWEInfos, Rank, TorusPrecision}, }; #[derive(PartialEq, Eq)] @@ -17,6 +17,16 @@ pub struct GLWEPublicKeyPrepared { pub(crate) dist: Distribution, } +pub(crate) trait SetDist { + fn set_dist(&mut self, dist: Distribution); +} + +impl SetDist for GLWEPublicKeyPrepared { + fn set_dist(&mut self, dist: Distribution) { + self.dist = dist + } +} + impl LWEInfos for GLWEPublicKeyPrepared { fn base2k(&self) -> Base2K { self.base2k @@ -166,40 +176,104 @@ impl GLWEPublicKeyPrepared, B> { } } -impl PrepareAlloc, B>> for GLWEPublicKey -where - Module: VecZnxDftAlloc + VecZnxDftApply, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { - let mut pk_prepared: GLWEPublicKeyPrepared, B> = GLWEPublicKeyPrepared::alloc(module, self); - pk_prepared.prepare(module, self, scratch); - pk_prepared - } +pub trait GLWEPublicKeyPrepareTmpBytes { + fn glwe_public_key_prepare_tmp_bytes(&self, infos: &A) + where + A: GLWEInfos; } -impl PrepareScratchSpace for GLWEPublicKeyPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { +impl GLWEPublicKeyPrepareTmpBytes for Module { + fn glwe_public_key_prepare_tmp_bytes(&self, infos: &A) + where + A: GLWEInfos, + { 0 } } -impl Prepare> for GLWEPublicKeyPrepared +impl GLWEPublicKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &Module, infos: &A) + where + A: GLWEInfos, + Module: GLWEPublicKeyPrepareTmpBytes, + { + module.glwe_public_key_prepare_tmp_bytes(infos); + } +} + +pub trait GLWEPublicKeyPrepare { + fn glwe_public_key_prepare(&self, res: &mut R, other: &O, scratch: &Scratch) + where + R: GLWEPublicKeyPreparedToMut + SetDist, + O: GLWEPublicKeyToRef; +} + +impl GLWEPublicKeyPrepare for Module where - Module: VecZnxDftApply, + Module: VecZnxDftAlloc + VecZnxDftApply, { - fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch) { - #[cfg(debug_assertions)] + fn glwe_public_key_prepare(&self, res: &mut R, other: &O, scratch: &Scratch) + where + R: GLWEPublicKeyPreparedToMut + SetDist, + O: GLWEPublicKeyToRef, + { { - assert_eq!(self.n(), other.n()); - assert_eq!(self.size(), other.size()); + let res: GLWEPublicKeyPrepared<&mut [u8], B> = res.to_mut(); + let other: GLWEPublicKey<&[u8]> = other.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(other.n(), self.n() as u32); + assert_eq!(res.size(), other.size()); + assert_eq!(res.k(), other.k()); + assert_eq!(res.base2k(), other.base2k()); + + for i in 0..(self.rank() + 1).into() { + self.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); + } } - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); - }); - self.k = other.k(); - self.base2k = other.base2k(); - self.dist = other.dist; + res.set_dist(other.dist); + } +} + +impl GLWEPublicKeyPrepared +where + Module: GLWEPublicKeyPrepare, +{ + pub fn prepare(&mut self, module: &Module, other: &O, scratch: &mut Scratch) + where + O: GLWEPublicKeyToRef, + { + module.glwe_public_key_prepare(self, other, scratch); + } +} + +pub trait GLWEPublicKeyPrepareAlloc { + fn glwe_public_key_prepare_alloc(&self, other: &O, scratch: &mut Scratch) + where + O: GLWEPublicKeyToRef; +} + +impl GLWEPublicKeyPrepareAlloc for Module +where + Module: GLWEPublicKeyPrepare, +{ + fn glwe_public_key_prepare_alloc(&self, other: &O, scratch: &mut Scratch) + where + O: GLWEPublicKeyToRef, + { + let mut ct_prepared: GLWEPublicKeyPrepared, B> = GLWEPublicKeyPrepared::alloc(self, other); + self.glwe_public_key_prepare(&mut ct_prepared, ct_prepared, scratch); + ct_prepared + } +} + +impl GLWEPublicKey { + pub fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) + where + Module: GLWEPublicKeyPrepareAlloc, + { + module.glwe_public_key_prepare_alloc(self, scratch); } } diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs index a8e436d..c75f425 100644 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_sk.rs @@ -5,7 +5,10 @@ use poulpy_hal::{ use crate::{ dist::Distribution, - layouts::{Base2K, Degree, GLWEInfos, GLWESecret, LWEInfos, Rank, TorusPrecision}, + layouts::{ + Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank, TorusPrecision, + prepared::SetDist, + }, }; pub struct GLWESecretPrepared { @@ -13,6 +16,12 @@ pub struct GLWESecretPrepared { pub(crate) dist: Distribution, } +impl SetDist for GLWESecretPrepared { + fn set_dist(&mut self, dist: Distribution) { + self.dist = dist + } +} + impl LWEInfos for GLWESecretPrepared { fn base2k(&self) -> Base2K { Base2K(0) @@ -82,32 +91,79 @@ impl GLWESecretPrepared { } } -impl PrepareScratchSpace for GLWESecretPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { +pub trait GLWESecretPrepareTmpBytes { + fn glwe_secret_prepare_tmp_bytes(&self, infos: &A) + where + A: GLWEInfos; +} + +impl GLWESecretPrepareTmpBytes for Module { + fn glwe_secret_prepare_tmp_bytes(&self, infos: &A) + where + A: GLWEInfos, + { 0 } } -impl PrepareAlloc, B>> for GLWESecret +impl GLWESecretPrepared, B> where - Module: SvpPrepare + SvpPPolAlloc, + Module: GLWESecretPrepareTmpBytes, { - fn prepare_alloc(&self, module: &Module, _scratch: &mut Scratch) -> GLWESecretPrepared, B> { - let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self); - sk_dft.prepare(module, self, _scratch); - sk_dft + fn prepare_tmp_bytes(&self, module: &Module, infos: &A) -> usize + where + A: GLWEInfos, + { + 0 } } -impl Prepare> for GLWESecretPrepared +pub trait GLWESecretPrepare { + fn glwe_secret_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWESecretPreparedToMut + SetDist, + O: GLWESecretToRef; +} + +impl GLWESecretPrepare for Module where Module: SvpPrepare, { - fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut Scratch) { - (0..self.rank().into()).for_each(|i| { - module.svp_prepare(&mut self.data, i, &other.data, i); - }); - self.dist = other.dist + fn glwe_secret_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWESecretPreparedToMut + SetDist, + O: GLWESecretToRef, + { + { + let res: GLWESecretPrepared<&mut [u8], _> = res.to_mut(); + let other: GLWESecret<&[u8]> = other.to_ref(); + + for i in 0..self.rank().into() { + self.svp_prepare(&mut res.data, i, &other.data, i); + } + } + + res.set_dist(other.dist); + } +} + +pub trait GLWESecretPrepareAlloc { + fn glwe_secret_prepare_alloc(&self, other: &O, scratch: &mut Scratch) + where + O: GLWESecretToMut; +} + +impl GLWESecretPrepareAlloc for Module +where + Module: GLWESecretPrepare, +{ + fn glwe_secret_prepare_alloc(&self, other: &O, scratch: &mut Scratch) + where + O: GLWESecretToMut, + { + let mut ct_prep: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(self, self); + self.glwe_secret_prepare(&mut ct_prep, other, scratch); + ct_prep } }