mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization * updated method signatures to take layouts * fixed cross-base normalization fix #91 fix #93
This commit is contained in:
committed by
GitHub
parent
4da790ea6a
commit
37e13b965c
@@ -2,7 +2,7 @@ use poulpy_hal::{
|
||||
api::{
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
|
||||
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
|
||||
source::Source,
|
||||
@@ -11,15 +11,16 @@ use poulpy_hal::{
|
||||
use crate::{
|
||||
TakeGLWEPt,
|
||||
encryption::{SIGMA, glwe_encrypt_sk_internal},
|
||||
layouts::{GGLWECiphertext, Infos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared},
|
||||
layouts::{GGLWECiphertext, GGLWELayoutInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared},
|
||||
};
|
||||
|
||||
impl GGLWECiphertextCompressed<Vec<u8>> {
|
||||
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
|
||||
pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
|
||||
where
|
||||
A: GGLWELayoutInfos,
|
||||
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
|
||||
GGLWECiphertext::encrypt_sk_scratch_space(module, infos)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +43,7 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxSubInplace
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxAddNormal
|
||||
@@ -56,7 +57,7 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
|
||||
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
pt.cols(),
|
||||
pt.cols() as u32,
|
||||
"self.rank_in(): {} != pt.cols(): {}",
|
||||
self.rank_in(),
|
||||
pt.cols()
|
||||
@@ -69,36 +70,33 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
|
||||
sk.rank()
|
||||
);
|
||||
assert_eq!(self.n(), sk.n());
|
||||
assert_eq!(pt.n(), sk.n());
|
||||
assert_eq!(pt.n() as u32, sk.n());
|
||||
assert!(
|
||||
scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()),
|
||||
"scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
|
||||
scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self),
|
||||
"scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}",
|
||||
scratch.available(),
|
||||
self.rank(),
|
||||
self.size(),
|
||||
GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k())
|
||||
GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self)
|
||||
);
|
||||
assert!(
|
||||
self.rows() * self.digits() * self.basek() <= self.k(),
|
||||
"self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}",
|
||||
self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0,
|
||||
"self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}",
|
||||
self.rows(),
|
||||
self.digits(),
|
||||
self.basek(),
|
||||
self.rows() * self.digits() * self.basek(),
|
||||
self.base2k(),
|
||||
self.rows().0 * self.digits().0 * self.base2k().0,
|
||||
self.k()
|
||||
);
|
||||
}
|
||||
|
||||
let rows: usize = self.rows();
|
||||
let digits: usize = self.digits();
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
let rank_in: usize = self.rank_in();
|
||||
let cols: usize = self.rank_out() + 1;
|
||||
let rows: usize = self.rows().into();
|
||||
let digits: usize = self.digits().into();
|
||||
let base2k: usize = self.base2k().into();
|
||||
let rank_in: usize = self.rank_in().into();
|
||||
let cols: usize = (self.rank_out() + 1).into();
|
||||
|
||||
let mut source_xa = Source::new(seed);
|
||||
|
||||
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
|
||||
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self);
|
||||
(0..rank_in).for_each(|col_i| {
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
@@ -110,15 +108,15 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
|
||||
pt,
|
||||
col_i,
|
||||
);
|
||||
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1);
|
||||
module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
|
||||
|
||||
let (seed, mut source_xa_tmp) = source_xa.branch();
|
||||
self.seed[col_i * rows + row_i] = seed;
|
||||
|
||||
glwe_encrypt_sk_internal(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
self.base2k().into(),
|
||||
self.k().into(),
|
||||
&mut self.at_mut(row_i, col_i).data,
|
||||
cols,
|
||||
true,
|
||||
|
||||
Reference in New Issue
Block a user