Fix compressed encryptions & add GGSW compressed encryption (#67)

* Added decompress test

* updated encryption sampling & fixed bug in glwe -> lwe test

* Added GGSW compressed encryption
This commit is contained in:
Jean-Philippe Bossuat
2025-08-13 09:45:44 +02:00
committed by GitHub
parent 9aa4b1f1e2
commit 068470783e
13 changed files with 345 additions and 68 deletions

View File

@@ -6,7 +6,10 @@ use backend::hal::{
};
use sampling::source::Source;
use crate::{GGSWCiphertext, GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, TakeGLWEPt};
use crate::{
GGLWEEncryptSkFamily, GGSWCiphertext, GGSWCiphertextCompressed, GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos,
TakeGLWEPt, encrypt_sk_internal,
};
pub trait GGSWEncryptSkFamily<B: Backend> = GLWEEncryptSkFamily<B>;
@@ -77,3 +80,79 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
});
}
}
impl GGSWCiphertextCompressed<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: GGSWEncryptSkFamily<B> + VecZnxAllocBytes,
{
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
}
}
impl<DataSelf: DataMut> GGSWCiphertextCompressed<DataSelf> {
pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
pt: &ScalarZnx<DataPt>,
sk: &GLWESecretExec<DataSk, B>,
seed_xa: [u8; 32],
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GGSWEncryptSkFamily<B> + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx<B>,
{
#[cfg(debug_assertions)]
{
use backend::hal::api::ZnxInfos;
assert_eq!(self.rank(), sk.rank());
assert_eq!(self.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(sk.n(), module.n());
}
let basek: usize = self.basek();
let k: usize = self.k();
let rank: usize = self.rank();
let cols: usize = rank + 1;
let digits: usize = self.digits();
let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(module, basek, k);
let mut source = Source::new(seed_xa);
(0..self.rows()).for_each(|row_i| {
tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1);
(0..rank + 1).for_each(|col_j| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct
let (seed, mut source_xa_tmp) = source.branch();
self.seed[row_i * cols + col_j] = seed;
encrypt_sk_internal(
module,
self.basek(),
self.k(),
&mut self.at_mut(row_i, col_j).data,
cols,
true,
Some((&tmp_pt, col_j)),
sk,
&mut source_xa_tmp,
source_xe,
sigma,
scratch_1,
);
});
});
}
}

View File

@@ -3,11 +3,16 @@ use backend::hal::{
layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo},
};
use crate::{GGLWECiphertextCompressed, GGSWCiphertext, Infos};
use crate::{Decompress, GGSWCiphertext, GLWECiphertextCompressed, Infos};
#[derive(PartialEq, Eq)]
pub struct GGSWCiphertextCompressed<D: Data> {
pub(crate) data: GGLWECiphertextCompressed<D>,
pub(crate) data: MatZnx<D>,
pub(crate) basek: usize,
pub(crate) k: usize,
pub(crate) digits: usize,
pub(crate) rank: usize,
pub(crate) seed: Vec<[u8; 32]>,
}
impl GGSWCiphertextCompressed<Vec<u8>> {
@@ -15,8 +20,31 @@ impl GGSWCiphertextCompressed<Vec<u8>> {
where
Module<B>: MatZnxAlloc,
{
GGSWCiphertextCompressed {
data: GGLWECiphertextCompressed::alloc(module, basek, k, rows, digits, rank, rank),
let size: usize = k.div_ceil(basek);
debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
debug_assert!(
size > digits,
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
size,
digits
);
assert!(
rows * digits <= size,
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
rows,
digits,
size
);
Self {
data: module.mat_znx_alloc(rows, rank + 1, 1, k.div_ceil(basek)),
basek,
k: k,
digits,
rank,
seed: vec![[0u8; 32]; rows * (rank + 1)],
}
}
@@ -24,7 +52,48 @@ impl GGSWCiphertextCompressed<Vec<u8>> {
where
Module<B>: MatZnxAllocBytes,
{
GGLWECiphertextCompressed::bytes_of(module, basek, k, rows, digits, rank)
let size: usize = k.div_ceil(basek);
debug_assert!(
size > digits,
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
size,
digits
);
assert!(
rows * digits <= size,
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
rows,
digits,
size
);
module.mat_znx_alloc_bytes(rows, rank + 1, 1, size)
}
}
impl<D: DataRef> GGSWCiphertextCompressed<D> {
pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> {
GLWECiphertextCompressed {
data: self.data.at(row, col),
basek: self.basek,
k: self.k,
rank: self.rank(),
seed: self.seed[row * (self.rank() + 1) + col],
}
}
}
impl<D: DataMut> GGSWCiphertextCompressed<D> {
pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> {
let rank: usize = self.rank();
GLWECiphertextCompressed {
data: self.data.at_mut(row, col),
basek: self.basek,
k: self.k,
rank: rank,
seed: self.seed[row * (rank + 1) + col],
}
}
}
@@ -32,25 +101,25 @@ impl<D: Data> Infos for GGSWCiphertextCompressed<D> {
type Inner = MatZnx<D>;
fn inner(&self) -> &Self::Inner {
self.data.inner()
&self.data
}
fn basek(&self) -> usize {
self.data.basek()
self.basek
}
fn k(&self) -> usize {
self.data.k()
self.k
}
}
impl<D: Data> GGSWCiphertextCompressed<D> {
pub fn rank(&self) -> usize {
self.data.rank()
self.rank
}
pub fn digits(&self) -> usize {
self.data.digits()
self.digits
}
}
@@ -66,15 +135,23 @@ impl<D: DataRef> WriterTo for GGSWCiphertextCompressed<D> {
}
}
impl<D: DataMut> GGSWCiphertext<D> {
pub fn decompress<DataOther: DataRef, B: Backend>(&mut self, module: &Module<B>, other: &GGSWCiphertextCompressed<DataOther>)
impl<D: DataMut, B: Backend, DR: DataRef> Decompress<B, GGSWCiphertextCompressed<DR>> for GGSWCiphertext<D> {
fn decompress(&mut self, module: &Module<B>, other: &GGSWCiphertextCompressed<DR>)
where
Module<B>: VecZnxFillUniform + VecZnxCopy,
{
let rows = self.rows();
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), other.rank())
}
let rows: usize = self.rows();
let rank: usize = self.rank();
(0..rows).for_each(|row_i| {
self.at_mut(row_i, 0)
.decompress(module, &other.data.at(row_i, 0));
(0..rank + 1).for_each(|col_j| {
self.at_mut(row_i, col_j)
.decompress(module, &other.at(row_i, col_j));
});
});
}
}

View File

@@ -4,8 +4,8 @@ use backend::{
};
use crate::ggsw::test::generic_tests::{
test_automorphism, test_automorphism_inplace, test_encrypt_sk, test_external_product, test_external_product_inplace,
test_keyswitch, test_keyswitch_inplace,
test_automorphism, test_automorphism_inplace, test_encrypt_sk, test_encrypt_sk_compressed, test_external_product,
test_external_product_inplace, test_keyswitch, test_keyswitch_inplace,
};
#[test]
@@ -23,6 +23,21 @@ fn encrypt_sk() {
});
}
#[test]
fn encrypt_sk_compressed() {
let log_n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let basek: usize = 12;
let k_ct: usize = 54;
let digits: usize = k_ct / basek;
(1..4).for_each(|rank| {
(1..digits + 1).for_each(|di| {
println!("test encrypt_sk_compressed digits: {} rank: {}", di, rank);
test_encrypt_sk_compressed(&module, basek, k_ct, di, rank, 3.2);
});
});
}
#[test]
fn keyswitch() {
let log_n: usize = 8;

View File

@@ -1,7 +1,7 @@
use backend::hal::{
api::{
MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace,
VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxRotateInplace, VecZnxStd,
VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxRotateInplace, VecZnxStd,
VecZnxSubABInplace, VecZnxSwithcDegree, ZnxViewMut,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned},
@@ -13,9 +13,10 @@ use backend::hal::{
use sampling::source::Source;
use crate::{
AutomorphismKey, AutomorphismKeyExec, GGLWEExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext, GGSWCiphertextExec,
GGSWEncryptSkFamily, GGSWKeySwitchFamily, GLWESecret, GLWESecretExec, GLWESecretFamily, GLWESwitchingKey,
GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, GLWETensorKey, GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec,
AutomorphismKey, AutomorphismKeyExec, Decompress, GGLWEExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext,
GGSWCiphertextCompressed, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWKeySwitchFamily, GLWESecret, GLWESecretExec,
GLWESecretFamily, GLWESwitchingKey, GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, GLWETensorKey,
GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec,
noise::{noise_ggsw_keyswitch, noise_ggsw_product},
};
@@ -29,7 +30,8 @@ pub(crate) trait TestModuleFamily<B: Backend> = GLWESecretFamily<B>
+ VecZnxAddScalarInplace
+ VecZnxSubABInplace
+ VecZnxStd
+ ScalarZnxAllocBytes;
+ ScalarZnxAllocBytes
+ VecZnxCopy;
pub(crate) trait TestScratchFamily<B: Backend> = TakeVecZnxDftImpl<B>
+ TakeVecZnxBigImpl<B>
+ TakeSvpPPolImpl<B>
@@ -83,6 +85,58 @@ where
ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f);
}
pub(crate) fn test_encrypt_sk_compressed<B: Backend>(
module: &Module<B>,
basek: usize,
k: usize,
digits: usize,
rank: usize,
sigma: f64,
) where
Module<B>: TestModuleFamily<B>,
B: TestScratchFamily<B>,
{
let rows: usize = (k - digits * basek) / (digits * basek);
let mut ct_compressed: GGSWCiphertextCompressed<Vec<u8>> =
GGSWCiphertextCompressed::alloc(module, basek, k, rows, digits, rank);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space(
module, basek, k, rank,
));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_exec: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk);
sk_exec.prepare(module, &sk);
let seed_xa: [u8; 32] = [1u8; 32];
ct_compressed.encrypt_sk(
module,
&pt_scalar,
&sk_exec,
seed_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let noise_f = |_col_i: usize| -(k as f64) + sigma.log2() + 0.5;
let mut ct: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(module, basek, k, rows, digits, rank);
ct.decompress(module, &ct_compressed);
ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f);
}
pub(crate) fn test_keyswitch<B: Backend>(
module: &Module<B>,
basek: usize,