This commit is contained in:
Jean-Philippe Bossuat
2026-01-17 07:47:54 +01:00
committed by GitHub
parent 2559d8ea81
commit f679f6874d
3 changed files with 71 additions and 21 deletions

View File

@@ -1,17 +1,17 @@
use poulpy_hal::{ use poulpy_hal::{
api::VecZnxNormalizeInplace, api::{VecZnxNormalize, VecZnxNormalizeTmpBytes},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut},
}; };
use crate::{ use crate::{
ScratchTakeCore, ScratchTakeCore,
layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}, layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToRef, SetLWEInfos, TorusPrecision},
}; };
impl<DataSelf: DataRef + DataMut> LWE<DataSelf> { impl<DataSelf: DataRef + DataMut> LWE<DataSelf> {
pub fn decrypt<P, S, M, BE: Backend>(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>) pub fn decrypt<P, S, M, BE: Backend>(&self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where where
P: LWEPlaintextToMut, P: LWEPlaintextToMut + SetLWEInfos + LWEInfos,
S: LWESecretToRef, S: LWESecretToRef,
M: LWEDecrypt<BE>, M: LWEDecrypt<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
@@ -21,27 +21,36 @@ impl<DataSelf: DataRef + DataMut> LWE<DataSelf> {
} }
pub trait LWEDecrypt<BE: Backend> { pub trait LWEDecrypt<BE: Backend> {
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>) fn lwe_decrypt<R, P, S>(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where where
R: LWEToMut, R: LWEToRef,
P: LWEPlaintextToMut, P: LWEPlaintextToMut + SetLWEInfos + LWEInfos,
S: LWESecretToRef, S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>; Scratch<BE>: ScratchTakeCore<BE>;
fn lwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: LWEInfos;
} }
impl<BE: Backend> LWEDecrypt<BE> for Module<BE> impl<BE: Backend> LWEDecrypt<BE> for Module<BE>
where where
Self: Sized + VecZnxNormalizeInplace<BE>, Self: Sized + VecZnxNormalize<BE> + VecZnxNormalizeTmpBytes,
{ {
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>) fn lwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
where where
R: LWEToMut, A: LWEInfos,
P: LWEPlaintextToMut, {
self.vec_znx_normalize_tmp_bytes() + LWEPlaintext::bytes_of(infos.size())
}
fn lwe_decrypt<R, P, S>(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where
R: LWEToRef,
P: LWEPlaintextToMut + SetLWEInfos + LWEInfos,
S: LWESecretToRef, S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let res: &LWE<&[u8]> = &res.to_ref();
let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut();
let sk: LWESecret<&[u8]> = sk.to_ref(); let sk: LWESecret<&[u8]> = sk.to_ref();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -49,16 +58,20 @@ where
assert_eq!(res.n(), sk.n()); assert_eq!(res.n(), sk.n());
} }
(0..pt.size().min(res.size())).for_each(|i| { let (mut tmp, scratch_1) = scratch.take_lwe_plaintext(res);
pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0] for i in 0..res.size() {
tmp.data.at_mut(0, i)[0] = res.data.at(0, i)[0]
+ res.data.at(0, i)[1..] + res.data.at(0, i)[1..]
.iter() .iter()
.zip(sk.data.at(0, 0)) .zip(sk.data.at(0, 0))
.map(|(x, y)| x * y) .map(|(x, y)| x * y)
.sum::<i64>(); .sum::<i64>();
}); }
self.vec_znx_normalize_inplace(res.base2k().into(), &mut pt.data, 0, scratch);
pt.base2k = res.base2k(); let pt_base2k = pt.base2k().into();
pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); let res_base2k = res.base2k().into();
self.vec_znx_normalize(&mut pt.to_mut().data, pt_base2k, 0, 0, tmp.data(), res_base2k, 0, scratch_1);
pt.set_k(TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)));
} }
} }

View File

@@ -2,7 +2,7 @@ use std::fmt;
use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos};
use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; use crate::layouts::{Base2K, Degree, LWEInfos, SetLWEInfos, TorusPrecision};
#[derive(PartialEq, Eq, Copy, Clone, Debug)] #[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub struct LWEPlaintextLayout { pub struct LWEPlaintextLayout {
@@ -34,6 +34,16 @@ pub struct LWEPlaintext<D: Data> {
pub(crate) base2k: Base2K, pub(crate) base2k: Base2K,
} }
impl<D: DataMut> SetLWEInfos for LWEPlaintext<D> {
fn set_base2k(&mut self, base2k: Base2K) {
self.base2k = base2k
}
fn set_k(&mut self, k: TorusPrecision) {
self.k = k
}
}
impl<D: Data> LWEInfos for LWEPlaintext<D> { impl<D: Data> LWEInfos for LWEPlaintext<D> {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.base2k self.base2k
@@ -67,6 +77,17 @@ impl LWEPlaintext<Vec<u8>> {
base2k, base2k,
} }
} }
pub fn bytes_of_from_infos<A>(infos: &A) -> usize
where
A: LWEInfos,
{
Self::bytes_of(infos.size())
}
pub fn bytes_of(size: usize) -> usize {
VecZnx::bytes_of(1, 1, size)
}
} }
impl<D: DataRef> fmt::Display for LWEPlaintext<D> { impl<D: DataRef> fmt::Display for LWEPlaintext<D> {

View File

@@ -7,7 +7,8 @@ use crate::{
dist::Distribution, dist::Distribution,
layouts::{ layouts::{
Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank, GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, LWEPlaintext,
Rank,
prepared::{ prepared::{
GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
@@ -34,6 +35,21 @@ where
) )
} }
fn take_lwe_plaintext<A>(&mut self, infos: &A) -> (LWEPlaintext<&mut [u8]>, &mut Self)
where
A: LWEInfos,
{
let (data, scratch) = self.take_vec_znx(1, 1, infos.size());
(
LWEPlaintext {
k: infos.k(),
base2k: infos.base2k(),
data,
},
scratch,
)
}
fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
where where
A: GLWEInfos, A: GLWEInfos,