Added grlwe encrypt + test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-08 15:35:21 +02:00
parent 1f384ce54d
commit 107e83c65c
2 changed files with 207 additions and 14 deletions

View File

@@ -1,6 +1,16 @@
use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module};
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps,
ZnxZero,
};
use sampling::source::Source;
use crate::{elem::Infos, utils::derive_size};
use crate::{
elem::Infos,
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt},
keys::SecretKeyDft,
utils::derive_size,
};
pub struct GRLWECt<C, B: Backend> {
pub data: MatZnxDft<C, B>,
@@ -18,6 +28,18 @@ impl<B: Backend> GRLWECt<Vec<u8>, B> {
}
}
impl<C> GRLWECt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
pub fn get_row(&self, module: &Module<FFT64>, i: usize, res: &mut RLWECtDft<C, FFT64>)
where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, i, 0);
}
}
impl<T, B: Backend> Infos for GRLWECt<T, B> {
type Inner = MatZnxDft<T, B>;
@@ -51,3 +73,169 @@ where
self.data.to_ref()
}
}
impl GRLWECt<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
RLWECt::encrypt_sk_scratch_bytes(module, size)
+ module.bytes_of_vec_znx(2, size)
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size)
}
// pub fn encrypt_pk_scratch_bytes(module: &Module<FFT64>, pk_size: usize) -> usize {
// RLWECt::encrypt_pk_scratch_bytes(module, pk_size)
// }
}
pub fn encrypt_grlwe_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut GRLWECt<C, FFT64>,
pt: &ScalarZnx<P>,
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let rows: usize = ct.rows();
let size: usize = ct.size();
let log_base2k: usize = ct.log_base2k();
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
data: tmp_znx_pt,
log_base2k: log_base2k,
log_k: ct.log_k(),
};
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
data: tmp_znx_ct,
log_base2k: log_base2k,
log_k: ct.log_k(),
};
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk(
module,
Some(&vec_znx_pt),
sk,
source_xa,
source_xe,
sigma,
bound,
scratch_3,
);
vec_znx_pt.data.zero(); // zeroes for next iteration
// Switch vec_znx_ct into DFT domain
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct);
});
}
impl<C> GRLWECt<C, FFT64> {
pub fn encrypt_sk<P, S>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<P>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_grlwe_sk(
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
}
#[cfg(test)]
mod tests {
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps};
use sampling::source::Source;
use crate::{
elem::Infos,
elem_rlwe::{RLWECtDft, RLWEPt},
keys::{SecretKey, SecretKeyDft},
};
use super::GRLWECt;
#[test]
fn encrypt_sk_vec_znx_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
let rows: usize = 4;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows);
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
(0..ct.rows()).for_each(|row_i| {
ct.get_row(&module, row_i, &mut ct_rlwe_dft);
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt);
});
module.free();
}
}

View File

@@ -181,7 +181,7 @@ pub fn encrypt_rlwe_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
sk: &SecretKeyDft<S, FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
@@ -206,7 +206,7 @@ pub fn encrypt_rlwe_sk<C, P, S>(
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
@@ -227,7 +227,7 @@ pub fn decrypt_rlwe<P, C, S>(
module: &Module<FFT64>,
pt: &mut RLWEPt<P>,
ct: &RLWECt<C>,
sk: &SecretKeyDft<S, FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<P>: VecZnxToMut + VecZnxToRef,
@@ -241,7 +241,7 @@ pub fn decrypt_rlwe<P, C, S>(
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
@@ -262,7 +262,7 @@ impl<C> RLWECt<C> {
&mut self,
module: &Module<FFT64>,
pt: Option<&RLWEPt<P>>,
sk: &SecretKeyDft<S, FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
@@ -274,17 +274,22 @@ impl<C> RLWECt<C> {
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_rlwe_sk(
module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch,
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
pub fn decrypt<P, S>(&self, module: &Module<FFT64>, pt: &mut RLWEPt<P>, sk: &SecretKeyDft<S, FFT64>, scratch: &mut Scratch)
where
pub fn decrypt<P, S>(
&self,
module: &Module<FFT64>,
pt: &mut RLWEPt<P>,
sk_dft: &SecretKeyDft<S, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<P>: VecZnxToMut + VecZnxToRef,
VecZnx<C>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
decrypt_rlwe(module, pt, self, sk, scratch);
decrypt_rlwe(module, pt, self, sk_dft, scratch);
}
pub fn encrypt_pk<P, S>(
@@ -526,7 +531,7 @@ mod tests {
};
#[test]
fn encrypt_sk_vec_znx_fft64() {
fn encrypt_sk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
@@ -597,7 +602,7 @@ mod tests {
}
#[test]
fn encrypt_zero_rlwe_dft_sk_fft64() {
fn encrypt_zero_sk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1024);
let log_base2k: usize = 8;
let log_k_ct: usize = 55;
@@ -639,7 +644,7 @@ mod tests {
}
#[test]
fn encrypt_pk_vec_znx_fft64() {
fn encrypt_pk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;