From 068470783e6045372b8ecfc788106ff43fdfa982 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 13 Aug 2025 09:45:44 +0200 Subject: [PATCH] Fix compressed encryptions & add GGSW compressed encryption (#67) * Added decompress test * updated encryption sampling & fixed bug in glwe -> lwe test * Added GGSW compressed encryption --- core/src/blind_rotation/test/cggi.rs | 5 +- core/src/elem.rs | 11 ++- core/src/gglwe/encryption.rs | 2 + core/src/gglwe/layouts_compressed.rs | 6 +- core/src/ggsw/encryption.rs | 81 +++++++++++++++- core/src/ggsw/layout_compressed.rs | 107 +++++++++++++++++++--- core/src/ggsw/test/cpu_spqlios/fft64.rs | 19 +++- core/src/ggsw/test/generic_tests.rs | 64 ++++++++++++- core/src/glwe/encryption.rs | 67 ++++++++------ core/src/glwe/keyswitch.rs | 23 ++++- core/src/glwe/layout.rs | 17 +++- core/src/glwe/tests/generic_encryption.rs | 9 +- core/src/lwe/test_fft64/conversion.rs | 2 +- 13 files changed, 345 insertions(+), 68 deletions(-) diff --git a/core/src/blind_rotation/test/cggi.rs b/core/src/blind_rotation/test/cggi.rs index a04d5c1..73f7c57 100644 --- a/core/src/blind_rotation/test/cggi.rs +++ b/core/src/blind_rotation/test/cggi.rs @@ -3,7 +3,7 @@ use backend::{ api::{ MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxRotateInplace, - VecZnxSwithcDegree, ZnxView, + VecZnxSub, VecZnxSwithcDegree, ZnxView, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -57,7 +57,8 @@ pub(crate) trait CGGITestModuleFamily = CCGIBlindRotationFamily + VecZnxEncodeCoeffsi64 + VecZnxRotateInplace + VecZnxSwithcDegree - + MatZnxAlloc; + + MatZnxAlloc + + VecZnxSub; pub(crate) trait CGGITestScratchFamily = VecZnxDftAllocBytesImpl + VecZnxBigAllocBytesImpl + ScratchOwnedAllocImpl diff --git a/core/src/elem.rs b/core/src/elem.rs index 6de038b..3ba5499 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,4 +1,7 @@ -use backend::hal::api::ZnxInfos; +use backend::hal::{ + api::{VecZnxCopy, VecZnxFillUniform, ZnxInfos}, + layouts::{Backend, Module}, +}; pub trait Infos { type Inner: ZnxInfos; @@ -52,3 +55,9 @@ pub trait SetMetaData { fn set_basek(&mut self, basek: usize); fn set_k(&mut self, k: usize); } + +pub trait Decompress { + fn decompress(&mut self, module: &Module, other: &C) + where + Module: VecZnxFillUniform + VecZnxCopy; +} diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 27bc91d..9754002 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -192,6 +192,7 @@ impl GGLWECiphertextCompressed { let basek: usize = self.basek(); let k: usize = self.k(); let rank_in: usize = self.rank_in(); + let cols: usize = self.rank_out() + 1; let mut source_xa = Source::new(seed); @@ -217,6 +218,7 @@ impl GGLWECiphertextCompressed { self.basek(), self.k(), &mut self.at_mut(row_i, col_i).data, + cols, true, Some((&tmp_pt, 0)), sk, diff --git a/core/src/gglwe/layouts_compressed.rs b/core/src/gglwe/layouts_compressed.rs index eee5b9f..305d812 100644 --- a/core/src/gglwe/layouts_compressed.rs +++ b/core/src/gglwe/layouts_compressed.rs @@ -3,7 +3,7 @@ use backend::hal::{ layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, }; -use crate::{AutomorphismKey, GGLWECiphertext, GLWECiphertextCompressed, GLWESwitchingKey, GLWETensorKey, Infos}; +use crate::{AutomorphismKey, Decompress, GGLWECiphertext, GLWECiphertextCompressed, GLWESwitchingKey, GLWETensorKey, Infos}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; #[derive(PartialEq, Eq)] @@ -169,8 +169,8 @@ impl WriterTo for GGLWECiphertextCompressed { } } -impl GGLWECiphertext { - pub fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) +impl Decompress> for GGLWECiphertext { + fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) where Module: VecZnxFillUniform + VecZnxCopy, { diff --git a/core/src/ggsw/encryption.rs b/core/src/ggsw/encryption.rs index 96abdf7..073f04c 100644 --- a/core/src/ggsw/encryption.rs +++ b/core/src/ggsw/encryption.rs @@ -6,7 +6,10 @@ use backend::hal::{ }; use sampling::source::Source; -use crate::{GGSWCiphertext, GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, TakeGLWEPt}; +use crate::{ + GGLWEEncryptSkFamily, GGSWCiphertext, GGSWCiphertextCompressed, GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, + TakeGLWEPt, encrypt_sk_internal, +}; pub trait GGSWEncryptSkFamily = GLWEEncryptSkFamily; @@ -77,3 +80,79 @@ impl GGSWCiphertext { }); } } + +impl GGSWCiphertextCompressed> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + { + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + } +} + +impl GGSWCiphertextCompressed { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &GLWESecretExec, + seed_xa: [u8; 32], + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + use backend::hal::api::ZnxInfos; + + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let basek: usize = self.basek(); + let k: usize = self.k(); + let rank: usize = self.rank(); + let cols: usize = rank + 1; + let digits: usize = self.digits(); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(module, basek, k); + + let mut source = Source::new(seed_xa); + + (0..self.rows()).for_each(|row_i| { + tmp_pt.data.zero(); + + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1); + + (0..rank + 1).for_each(|col_j| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + let (seed, mut source_xa_tmp) = source.branch(); + + self.seed[row_i * cols + col_j] = seed; + + encrypt_sk_internal( + module, + self.basek(), + self.k(), + &mut self.at_mut(row_i, col_j).data, + cols, + true, + Some((&tmp_pt, col_j)), + sk, + &mut source_xa_tmp, + source_xe, + sigma, + scratch_1, + ); + }); + }); + } +} diff --git a/core/src/ggsw/layout_compressed.rs b/core/src/ggsw/layout_compressed.rs index 4792d4f..2d139fd 100644 --- a/core/src/ggsw/layout_compressed.rs +++ b/core/src/ggsw/layout_compressed.rs @@ -3,11 +3,16 @@ use backend::hal::{ layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, }; -use crate::{GGLWECiphertextCompressed, GGSWCiphertext, Infos}; +use crate::{Decompress, GGSWCiphertext, GLWECiphertextCompressed, Infos}; #[derive(PartialEq, Eq)] pub struct GGSWCiphertextCompressed { - pub(crate) data: GGLWECiphertextCompressed, + pub(crate) data: MatZnx, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, + pub(crate) rank: usize, + pub(crate) seed: Vec<[u8; 32]>, } impl GGSWCiphertextCompressed> { @@ -15,8 +20,31 @@ impl GGSWCiphertextCompressed> { where Module: MatZnxAlloc, { - GGSWCiphertextCompressed { - data: GGLWECiphertextCompressed::alloc(module, basek, k, rows, digits, rank, rank), + let size: usize = k.div_ceil(basek); + debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); + + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.mat_znx_alloc(rows, rank + 1, 1, k.div_ceil(basek)), + basek, + k: k, + digits, + rank, + seed: vec![[0u8; 32]; rows * (rank + 1)], } } @@ -24,7 +52,48 @@ impl GGSWCiphertextCompressed> { where Module: MatZnxAllocBytes, { - GGLWECiphertextCompressed::bytes_of(module, basek, k, rows, digits, rank) + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.mat_znx_alloc_bytes(rows, rank + 1, 1, size) + } +} + +impl GGSWCiphertextCompressed { + pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { + GLWECiphertextCompressed { + data: self.data.at(row, col), + basek: self.basek, + k: self.k, + rank: self.rank(), + seed: self.seed[row * (self.rank() + 1) + col], + } + } +} + +impl GGSWCiphertextCompressed { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { + let rank: usize = self.rank(); + GLWECiphertextCompressed { + data: self.data.at_mut(row, col), + basek: self.basek, + k: self.k, + rank: rank, + seed: self.seed[row * (rank + 1) + col], + } } } @@ -32,25 +101,25 @@ impl Infos for GGSWCiphertextCompressed { type Inner = MatZnx; fn inner(&self) -> &Self::Inner { - self.data.inner() + &self.data } fn basek(&self) -> usize { - self.data.basek() + self.basek } fn k(&self) -> usize { - self.data.k() + self.k } } impl GGSWCiphertextCompressed { pub fn rank(&self) -> usize { - self.data.rank() + self.rank } pub fn digits(&self) -> usize { - self.data.digits() + self.digits } } @@ -66,15 +135,23 @@ impl WriterTo for GGSWCiphertextCompressed { } } -impl GGSWCiphertext { - pub fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) +impl Decompress> for GGSWCiphertext { + fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) where Module: VecZnxFillUniform + VecZnxCopy, { - let rows = self.rows(); + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), other.rank()) + } + + let rows: usize = self.rows(); + let rank: usize = self.rank(); (0..rows).for_each(|row_i| { - self.at_mut(row_i, 0) - .decompress(module, &other.data.at(row_i, 0)); + (0..rank + 1).for_each(|col_j| { + self.at_mut(row_i, col_j) + .decompress(module, &other.at(row_i, col_j)); + }); }); } } diff --git a/core/src/ggsw/test/cpu_spqlios/fft64.rs b/core/src/ggsw/test/cpu_spqlios/fft64.rs index 5ec265d..f18815c 100644 --- a/core/src/ggsw/test/cpu_spqlios/fft64.rs +++ b/core/src/ggsw/test/cpu_spqlios/fft64.rs @@ -4,8 +4,8 @@ use backend::{ }; use crate::ggsw::test::generic_tests::{ - test_automorphism, test_automorphism_inplace, test_encrypt_sk, test_external_product, test_external_product_inplace, - test_keyswitch, test_keyswitch_inplace, + test_automorphism, test_automorphism_inplace, test_encrypt_sk, test_encrypt_sk_compressed, test_external_product, + test_external_product_inplace, test_keyswitch, test_keyswitch_inplace, }; #[test] @@ -23,6 +23,21 @@ fn encrypt_sk() { }); } +#[test] +fn encrypt_sk_compressed() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = k_ct / basek; + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + println!("test encrypt_sk_compressed digits: {} rank: {}", di, rank); + test_encrypt_sk_compressed(&module, basek, k_ct, di, rank, 3.2); + }); + }); +} + #[test] fn keyswitch() { let log_n: usize = 8; diff --git a/core/src/ggsw/test/generic_tests.rs b/core/src/ggsw/test/generic_tests.rs index ec797ab..182ab16 100644 --- a/core/src/ggsw/test/generic_tests.rs +++ b/core/src/ggsw/test/generic_tests.rs @@ -1,7 +1,7 @@ use backend::hal::{ api::{ MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxRotateInplace, VecZnxStd, + VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxRotateInplace, VecZnxStd, VecZnxSubABInplace, VecZnxSwithcDegree, ZnxViewMut, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned}, @@ -13,9 +13,10 @@ use backend::hal::{ use sampling::source::Source; use crate::{ - AutomorphismKey, AutomorphismKeyExec, GGLWEExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext, GGSWCiphertextExec, - GGSWEncryptSkFamily, GGSWKeySwitchFamily, GLWESecret, GLWESecretExec, GLWESecretFamily, GLWESwitchingKey, - GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, GLWETensorKey, GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec, + AutomorphismKey, AutomorphismKeyExec, Decompress, GGLWEExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext, + GGSWCiphertextCompressed, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWKeySwitchFamily, GLWESecret, GLWESecretExec, + GLWESecretFamily, GLWESwitchingKey, GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, GLWETensorKey, + GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec, noise::{noise_ggsw_keyswitch, noise_ggsw_product}, }; @@ -29,7 +30,8 @@ pub(crate) trait TestModuleFamily = GLWESecretFamily + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd - + ScalarZnxAllocBytes; + + ScalarZnxAllocBytes + + VecZnxCopy; pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl + TakeVecZnxBigImpl + TakeSvpPPolImpl @@ -83,6 +85,58 @@ where ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f); } +pub(crate) fn test_encrypt_sk_compressed( + module: &Module, + basek: usize, + k: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily, + B: TestScratchFamily, +{ + let rows: usize = (k - digits * basek) / (digits * basek); + + let mut ct_compressed: GGSWCiphertextCompressed> = + GGSWCiphertextCompressed::alloc(module, basek, k, rows, digits, rank); + + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( + module, basek, k, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + sk_exec.prepare(module, &sk); + + let seed_xa: [u8; 32] = [1u8; 32]; + + ct_compressed.encrypt_sk( + module, + &pt_scalar, + &sk_exec, + seed_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let noise_f = |_col_i: usize| -(k as f64) + sigma.log2() + 0.5; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k, rows, digits, rank); + ct.decompress(module, &ct_compressed); + + ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f); +} + pub(crate) fn test_keyswitch( module: &Module, basek: usize, diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs index e749a96..bc9a95e 100644 --- a/core/src/glwe/encryption.rs +++ b/core/src/glwe/encryption.rs @@ -1,10 +1,10 @@ use backend::hal::{ api::{ ScalarZnxAllocBytes, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, - VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubABInplace, ZnxInfos, - ZnxZero, + TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAllocBytes, VecZnxBigAddNormal, + VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, + VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + VecZnxSub, VecZnxSubABInplace, ZnxInfos, ZnxZero, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig}, }; @@ -26,7 +26,9 @@ pub trait GLWEEncryptSkFamily = VecZnxDftAllocBytes + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal - + VecZnxNormalize; + + VecZnxNormalize + + VecZnxSub + + VecZnxAllocBytes; pub trait GLWEEncryptPkFamily = VecZnxDftAllocBytes + VecZnxBigAllocBytes @@ -47,7 +49,7 @@ impl GLWECiphertext> { { let size: usize = k.div_ceil(basek); module.vec_znx_normalize_tmp_bytes(module.n()) - + module.vec_znx_dft_alloc_bytes(1, size) + + 2 * module.vec_znx_alloc_bytes(1, size) + module.vec_znx_dft_alloc_bytes(1, size) } pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize @@ -146,11 +148,13 @@ impl GLWECiphertext { Module: GLWEEncryptSkFamily, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { + let cols: usize = self.rank() + 1; encrypt_sk_internal( module, self.basek(), self.k(), &mut self.data, + cols, false, pt, sk, @@ -345,11 +349,13 @@ impl GLWECiphertextCompressed { Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let mut source_xa = Source::new(seed_xa); + let cols: usize = self.rank() + 1; encrypt_sk_internal( module, self.basek(), self.k(), &mut self.data, + cols, true, pt, sk, @@ -367,6 +373,7 @@ pub(crate) fn encrypt_sk_internal, + cols: usize, compressed: bool, pt: Option<(&GLWEPlaintext, usize)>, sk: &GLWESecretExec, @@ -381,12 +388,6 @@ pub(crate) fn encrypt_sk_internal0 if compressed encryption" - ) - } assert_eq!( ct.cols(), 1, @@ -397,14 +398,15 @@ pub(crate) fn encrypt_sk_internal = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(basek, ct, 0, &ci_big, 0, scratch_2); + module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3); // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_ab_inplace(&mut c0, 0, ct, 0); - - // c[i] += m if col = i - // note: case cannot happen if compressed = true - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_add_inplace(ct, i, &pt.data, 0); - } - } + module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0); }); } diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs index 87e494d..a3b40aa 100644 --- a/core/src/glwe/keyswitch.rs +++ b/core/src/glwe/keyswitch.rs @@ -119,7 +119,28 @@ impl GLWECiphertext { rhs.digits(), rhs.rank_in(), rhs.rank_out(), - ) + ), + "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + )={}", + scratch.available(), + GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) ); } } diff --git a/core/src/glwe/layout.rs b/core/src/glwe/layout.rs index 4a40430..ba50a95 100644 --- a/core/src/glwe/layout.rs +++ b/core/src/glwe/layout.rs @@ -6,7 +6,7 @@ use backend::hal::{ }; use sampling::source::Source; -use crate::{GLWEOps, Infos, SetMetaData}; +use crate::{Decompress, GLWEOps, Infos, SetMetaData}; #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertext { @@ -266,10 +266,9 @@ impl WriterTo for GLWECiphertextCompressed { } } -impl GLWECiphertext { - pub fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) +impl Decompress> for GLWECiphertext { + fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) where - DataOther: DataRef, Module: VecZnxFillUniform + VecZnxCopy, { #[cfg(debug_assertions)] @@ -299,7 +298,9 @@ impl GLWECiphertext { self.decompress_internal(module, other, &mut source); } } +} +impl GLWECiphertext { pub(crate) fn decompress_internal( &mut self, module: &Module, @@ -309,13 +310,19 @@ impl GLWECiphertext { DataOther: DataRef, Module: VecZnxFillUniform + VecZnxCopy, { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), other.rank()) + } + let k: usize = other.k; let basek: usize = other.basek; - let cols: usize = other.cols(); + let cols: usize = other.rank() + 1; module.vec_znx_copy(&mut self.data, 0, &other.data, 0); (1..cols).for_each(|i| { module.vec_znx_fill_uniform(basek, &mut self.data, i, k, source); }); + self.basek = basek; self.k = k; } diff --git a/core/src/glwe/tests/generic_encryption.rs b/core/src/glwe/tests/generic_encryption.rs index 6639bbc..f4eb063 100644 --- a/core/src/glwe/tests/generic_encryption.rs +++ b/core/src/glwe/tests/generic_encryption.rs @@ -12,7 +12,7 @@ use backend::hal::{ use sampling::source::Source; use crate::{ - GLWECiphertext, GLWECiphertextCompressed, GLWEDecryptFamily, GLWEEncryptPkFamily, GLWEEncryptSkFamily, GLWEOps, + Decompress, GLWECiphertext, GLWECiphertextCompressed, GLWEDecryptFamily, GLWEEncryptPkFamily, GLWEEncryptSkFamily, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWEPublicKeyExec, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos, }; @@ -125,7 +125,12 @@ pub(crate) fn test_encrypt_sk_compressed( let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0) * (ct.k() as f64).exp2(); let noise_want: f64 = sigma; - assert!(noise_have <= noise_want + 0.2); + assert!( + noise_have <= noise_want + 0.2, + "{} <= {}", + noise_have, + noise_want + 0.2 + ); } pub(crate) fn test_encrypt_zero_sk(module: &Module, basek: usize, k_ct: usize, sigma: f64, rank: usize) diff --git a/core/src/lwe/test_fft64/conversion.rs b/core/src/lwe/test_fft64/conversion.rs index b403146..3d3a2f1 100644 --- a/core/src/lwe/test_fft64/conversion.rs +++ b/core/src/lwe/test_fft64/conversion.rs @@ -163,7 +163,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | LWECiphertext::from_glwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), );