From 107e83c65c062f8a6c5761846fd05baa30c52d85 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 15:35:21 +0200 Subject: [PATCH] Added grlwe encrypt + test --- rlwe/src/elem_grlwe.rs | 192 ++++++++++++++++++++++++++++++++++++++++- rlwe/src/elem_rlwe.rs | 29 ++++--- 2 files changed, 207 insertions(+), 14 deletions(-) diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index b269cb3..a0000cf 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -1,6 +1,16 @@ -use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module}; +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + ZnxZero, +}; +use sampling::source::Source; -use crate::{elem::Infos, utils::derive_size}; +use crate::{ + elem::Infos, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, + utils::derive_size, +}; pub struct GRLWECt { pub data: MatZnxDft, @@ -18,6 +28,18 @@ impl GRLWECt, B> { } } +impl GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + pub fn get_row(&self, module: &Module, i: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, i, 0); + } +} + impl Infos for GRLWECt { type Inner = MatZnxDft; @@ -51,3 +73,169 @@ where self.data.to_ref() } } + +impl GRLWECt, FFT64> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_bytes(module, size) + + module.bytes_of_vec_znx(2, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(2, size) + } + + // pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + // RLWECt::encrypt_pk_scratch_bytes(module, pk_size) + // } +} + +pub fn encrypt_grlwe_sk( + module: &Module, + ct: &mut GRLWECt, + pt: &ScalarZnx

, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let rows: usize = ct.rows(); + let size: usize = ct.size(); + let log_base2k: usize = ct.log_base2k(); + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); + let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + + let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + data: tmp_znx_pt, + log_base2k: log_base2k, + log_k: ct.log_k(), + }; + + let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + data: tmp_znx_ct, + log_base2k: log_base2k, + log_k: ct.log_k(), + }; + + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + Some(&vec_znx_pt), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch_3, + ); + + vec_znx_pt.data.zero(); // zeroes for next iteration + + // Switch vec_znx_ct into DFT domain + module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); + module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct); + }); +} + +impl GRLWECt { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx

, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_grlwe_sk( + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } +} + +#[cfg(test)] +mod tests { + use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::GRLWECt; + + #[test] + fn encrypt_sk_vec_znx_fft64() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); + } +} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 8a7d444..19b5496 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -181,7 +181,7 @@ pub fn encrypt_rlwe_sk( module: &Module, ct: &mut RLWECt, pt: Option<&RLWEPt

>, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -206,7 +206,7 @@ pub fn encrypt_rlwe_sk( module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); @@ -227,7 +227,7 @@ pub fn decrypt_rlwe( module: &Module, pt: &mut RLWEPt

, ct: &RLWECt, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, scratch: &mut Scratch, ) where VecZnx

: VecZnxToMut + VecZnxToRef, @@ -241,7 +241,7 @@ pub fn decrypt_rlwe( module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); @@ -262,7 +262,7 @@ impl RLWECt { &mut self, module: &Module, pt: Option<&RLWEPt

>, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -274,17 +274,22 @@ impl RLWECt { ScalarZnxDft: ScalarZnxDftToRef, { encrypt_rlwe_sk( - module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch, + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - pub fn decrypt(&self, module: &Module, pt: &mut RLWEPt

, sk: &SecretKeyDft, scratch: &mut Scratch) - where + pub fn decrypt( + &self, + module: &Module, + pt: &mut RLWEPt

, + sk_dft: &SecretKeyDft, + scratch: &mut Scratch, + ) where VecZnx

: VecZnxToMut + VecZnxToRef, VecZnx: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - decrypt_rlwe(module, pt, self, sk, scratch); + decrypt_rlwe(module, pt, self, sk_dft, scratch); } pub fn encrypt_pk( @@ -526,7 +531,7 @@ mod tests { }; #[test] - fn encrypt_sk_vec_znx_fft64() { + fn encrypt_sk_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -597,7 +602,7 @@ mod tests { } #[test] - fn encrypt_zero_rlwe_dft_sk_fft64() { + fn encrypt_zero_sk_fft64() { let module: Module = Module::::new(1024); let log_base2k: usize = 8; let log_k_ct: usize = 55; @@ -639,7 +644,7 @@ mod tests { } #[test] - fn encrypt_pk_vec_znx_fft64() { + fn encrypt_pk_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54;