From f679f6874d225670bab165f4f2e858569d6608d1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 17 Jan 2026 07:47:54 +0100 Subject: [PATCH] fix #130 (#133) --- poulpy-core/src/decryption/lwe.rs | 51 +++++++++++++++--------- poulpy-core/src/layouts/lwe_plaintext.rs | 23 ++++++++++- poulpy-core/src/scratch.rs | 18 ++++++++- 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/poulpy-core/src/decryption/lwe.rs b/poulpy-core/src/decryption/lwe.rs index 997d7ea..9f0ddbf 100644 --- a/poulpy-core/src/decryption/lwe.rs +++ b/poulpy-core/src/decryption/lwe.rs @@ -1,17 +1,17 @@ use poulpy_hal::{ - api::VecZnxNormalizeInplace, + api::{VecZnxNormalize, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, }; use crate::{ ScratchTakeCore, - layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}, + layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToRef, SetLWEInfos, TorusPrecision}, }; impl LWE { - pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) + pub fn decrypt(&self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) where - P: LWEPlaintextToMut, + P: LWEPlaintextToMut + SetLWEInfos + LWEInfos, S: LWESecretToRef, M: LWEDecrypt, Scratch: ScratchTakeCore, @@ -21,27 +21,36 @@ impl LWE { } pub trait LWEDecrypt { - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) + fn lwe_decrypt(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch) where - R: LWEToMut, - P: LWEPlaintextToMut, + R: LWEToRef, + P: LWEPlaintextToMut + SetLWEInfos + LWEInfos, S: LWESecretToRef, Scratch: ScratchTakeCore; + fn lwe_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: LWEInfos; } impl LWEDecrypt for Module where - Self: Sized + VecZnxNormalizeInplace, + Self: Sized + VecZnxNormalize + VecZnxNormalizeTmpBytes, { - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) + fn lwe_decrypt_tmp_bytes(&self, infos: &A) -> usize where - R: LWEToMut, - P: LWEPlaintextToMut, + A: LWEInfos, + { + self.vec_znx_normalize_tmp_bytes() + LWEPlaintext::bytes_of(infos.size()) + } + + fn lwe_decrypt(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + R: LWEToRef, + P: LWEPlaintextToMut + SetLWEInfos + LWEInfos, S: LWESecretToRef, Scratch: ScratchTakeCore, { - let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); - let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); + let res: &LWE<&[u8]> = &res.to_ref(); let sk: LWESecret<&[u8]> = sk.to_ref(); #[cfg(debug_assertions)] @@ -49,16 +58,20 @@ where assert_eq!(res.n(), sk.n()); } - (0..pt.size().min(res.size())).for_each(|i| { - pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0] + let (mut tmp, scratch_1) = scratch.take_lwe_plaintext(res); + for i in 0..res.size() { + tmp.data.at_mut(0, i)[0] = res.data.at(0, i)[0] + res.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); - }); - self.vec_znx_normalize_inplace(res.base2k().into(), &mut pt.data, 0, scratch); - pt.base2k = res.base2k(); - pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); + } + + let pt_base2k = pt.base2k().into(); + let res_base2k = res.base2k().into(); + self.vec_znx_normalize(&mut pt.to_mut().data, pt_base2k, 0, 0, tmp.data(), res_base2k, 0, scratch_1); + + pt.set_k(TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0))); } } diff --git a/poulpy-core/src/layouts/lwe_plaintext.rs b/poulpy-core/src/layouts/lwe_plaintext.rs index f568431..6d70829 100644 --- a/poulpy-core/src/layouts/lwe_plaintext.rs +++ b/poulpy-core/src/layouts/lwe_plaintext.rs @@ -2,7 +2,7 @@ use std::fmt; use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; -use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; +use crate::layouts::{Base2K, Degree, LWEInfos, SetLWEInfos, TorusPrecision}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWEPlaintextLayout { @@ -34,6 +34,16 @@ pub struct LWEPlaintext { pub(crate) base2k: Base2K, } +impl SetLWEInfos for LWEPlaintext { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + impl LWEInfos for LWEPlaintext { fn base2k(&self) -> Base2K { self.base2k @@ -67,6 +77,17 @@ impl LWEPlaintext> { base2k, } } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: LWEInfos, + { + Self::bytes_of(infos.size()) + } + + pub fn bytes_of(size: usize) -> usize { + VecZnx::bytes_of(1, 1, size) + } } impl fmt::Display for LWEPlaintext { diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index a6675da..b0ee9ee 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -7,7 +7,8 @@ use crate::{ dist::Distribution, layouts::{ Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, - GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank, + GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, LWEPlaintext, + Rank, prepared::{ GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, @@ -34,6 +35,21 @@ where ) } + fn take_lwe_plaintext(&mut self, infos: &A) -> (LWEPlaintext<&mut [u8]>, &mut Self) + where + A: LWEInfos, + { + let (data, scratch) = self.take_vec_znx(1, 1, infos.size()); + ( + LWEPlaintext { + k: infos.k(), + base2k: infos.base2k(), + data, + }, + scratch, + ) + } + fn take_glwe(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) where A: GLWEInfos,