From de3b34477d4013dd588f1c6deac9cdf7a2de15b3 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 18:32:19 +0200 Subject: [PATCH] added rgsw encrypt + test --- base2k/src/lib.rs | 4 +- base2k/src/scalar_znx.rs | 57 ++++++++- base2k/src/scalar_znx_dft_ops.rs | 206 +++++++++++++++---------------- base2k/src/vec_znx.rs | 6 +- base2k/src/vec_znx_big_ops.rs | 22 ++++ rlwe/src/elem_grlwe.rs | 8 +- rlwe/src/elem_rgsw.rs | 203 ++++++++++++++++++++++++------ rlwe/src/elem_rlwe.rs | 40 ++++-- 8 files changed, 384 insertions(+), 162 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index bb8ce55..b6ed099 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -196,7 +196,7 @@ impl Scratch { } } - pub fn tmp_scalar(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + pub fn tmp_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); ( @@ -205,7 +205,7 @@ impl Scratch { ) } - pub fn tmp_scalar_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { + pub fn tmp_scalar_znx_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); ( diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index dde286a..28ee38a 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -144,6 +144,17 @@ impl ScalarZnxToMut for ScalarZnx> { } } +impl VecZnxToMut for ScalarZnx>{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -154,6 +165,17 @@ impl ScalarZnxToRef for ScalarZnx> { } } +impl VecZnxToRef for ScalarZnx>{ + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { ScalarZnx { @@ -164,6 +186,17 @@ impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { } } +impl VecZnxToMut for ScalarZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -174,6 +207,17 @@ impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { } } +impl VecZnxToRef for ScalarZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} + impl ScalarZnxToRef for ScalarZnx<&[u8]> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { @@ -183,3 +227,14 @@ impl ScalarZnxToRef for ScalarZnx<&[u8]> { } } } + +impl VecZnxToRef for ScalarZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index f02fa03..1e0313a 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,103 +1,103 @@ -use crate::ffi::svp; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, - VecZnxDftToMut, VecZnxDftToRef, -}; - -pub trait ScalarZnxDftAlloc { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; -} - -pub trait ScalarZnxDftOps { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef; - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef; - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef; -} - -impl ScalarZnxDftAlloc for Module { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new(self, cols) - } - - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { - ScalarZnxDftOwned::bytes_of(self, cols) - } - - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) - } -} - -impl ScalarZnxDftOps for Module { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef, - { - unsafe { - svp::svp_prepare( - self.ptr, - res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, - a.to_ref().at_ptr(a_col, 0), - ) - } - } - - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - b.at_ptr(b_col, 0) as *const vec_znx_dft_t, - b.size() as u64, - b.cols() as u64, - ) - } - } - - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - res.at_ptr(res_col, 0) as *const vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - ) - } - } -} +use crate::ffi::svp; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{ + Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, +}; + +pub trait ScalarZnxDftAlloc { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; +} + +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarZnxToRef; + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef; + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef; +} + +impl ScalarZnxDftAlloc for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new(self, cols) + } + + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(self, cols) + } + + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) + } +} + +impl ScalarZnxDftOps for Module { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarZnxToRef, + { + unsafe { + svp::svp_prepare( + self.ptr, + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 0), + ) + } + } + + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, + b.size() as u64, + b.cols() as u64, + ) + } + } + + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + ) + } + } +} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 70d8fb3..31459d4 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -20,9 +20,9 @@ use std::{cmp::min, fmt}; /// are small polynomials of Zn\[X\]. pub struct VecZnx { pub data: D, - n: usize, - cols: usize, - size: usize, + pub n: usize, + pub cols: usize, + pub size: usize, } impl ZnxInfos for VecZnx { diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 933deb3..809a1eb 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -114,6 +114,9 @@ pub trait VecZnxBigOps { R: VecZnxBigToMut, A: VecZnxToRef; + /// Negates `a` inplace. + fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + /// Normalizes `a` and stores the result on `b`. /// /// # Arguments @@ -503,6 +506,25 @@ impl VecZnxBigOps for Module { } } + fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + fn vec_znx_big_normalize( &self, log_base2k: usize, diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index a0000cf..a460ec4 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -91,7 +91,7 @@ pub fn encrypt_grlwe_sk( module: &Module, ct: &mut GRLWECt, pt: &ScalarZnx

, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -131,7 +131,7 @@ pub fn encrypt_grlwe_sk( vec_znx_ct.encrypt_sk( module, Some(&vec_znx_pt), - sk, + sk_dft, source_xa, source_xe, sigma, @@ -186,7 +186,7 @@ mod tests { use super::GRLWECt; #[test] - fn encrypt_sk_vec_znx_fft64() { + fn encrypt_sk_fft64() { let module: Module = Module::::new(2048); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -233,7 +233,7 @@ mod tests { ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); }); module.free(); diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 1a1ea24..75d6583 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -1,13 +1,13 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + ZnxZero, }; use sampling::source::Source; use crate::{ elem::Infos, - elem_grlwe::GRLWECt, - elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, utils::derive_size, }; @@ -62,28 +62,32 @@ where } } -impl GRLWECt, FFT64> { +impl RGSWCt, FFT64> { pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { RLWECt::encrypt_sk_scratch_bytes(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } +} - pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { - RLWECt::encrypt_pk_scratch_bytes(module, pk_size) - } - - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECtDft::decrypt_scratch_bytes(module, size) +impl RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); } } -pub fn encrypt_grlwe_sk( +pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut GRLWECt, + ct: &mut RGSWCt, pt: &ScalarZnx

, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -94,47 +98,164 @@ pub fn encrypt_grlwe_sk( ScalarZnx

: ScalarZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let rows: usize = ct.rows(); let size: usize = ct.size(); + let log_base2k: usize = ct.log_base2k(); - let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); - let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { data: tmp_znx_pt, - log_base2k: ct.log_base2k(), + log_base2k: log_base2k, log_k: ct.log_k(), }; - let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { data: tmp_znx_ct, - log_base2k: ct.log_base2k(), + log_base2k: log_base2k, log_k: ct.log_k(), }; - (0..rows).for_each(|row_i| { - tmp_pt - .data - .at_mut(0, row_i) - .copy_from_slice(&pt.to_ref().raw()); + (0..ct.rows()).for_each(|row_j| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - tmp_ct.encrypt_sk( - module, - Some(&tmp_pt), - sk, - source_xa, - source_xe, - sigma, - bound, - scratch_3, - ); + (0..ct.cols()).for_each(|col_i| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + encrypt_rlwe_sk( + module, + &mut vec_znx_ct, + Some((&vec_znx_pt, col_i)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scrach_2, + ); - tmp_pt.data.at_mut(0, row_i).fill(0); + // Switch vec_znx_ct into DFT domain + { + let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size); + module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); + module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); + module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct); + } + }); - module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0); - module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1); - - module.vmp_prepare_row(ct, row_i, 0, &tmp_dft); + vec_znx_pt.data.zero(); // zeroes for next iteration }); } + +impl RGSWCt { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx

, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_rgsw_sk( + module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } +} + +#[cfg(test)] +mod tests { + use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, + }; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::RGSWCt; + + #[test] + fn encrypt_rgsw_sk_fft64() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); + + module.free(); + } +} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 19b5496..938b3c5 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -180,7 +180,7 @@ impl RLWECt> { pub fn encrypt_rlwe_sk( module: &Module, ct: &mut RLWECt, - pt: Option<&RLWEPt

>, + pt: Option<(&RLWEPt

, usize)>, sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, @@ -213,8 +213,18 @@ pub fn encrypt_rlwe_sk( } // c0_big = m - c0_big - if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); + if let Some((pt, col)) = pt { + match col { + 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), + 1 => { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + module.vec_znx_add_inplace(ct, 1, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); + } + _ => panic!("invalid target column: {}", col), + } + } else { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); } // c0_big += e c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); @@ -273,9 +283,23 @@ impl RLWECt { VecZnx

: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - encrypt_rlwe_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) + if let Some(pt) = pt { + encrypt_rlwe_sk( + module, + self, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ) + } else { + encrypt_rlwe_sk::( + module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } } pub fn decrypt( @@ -483,10 +507,10 @@ pub(crate) fn encrypt_rlwe_pk( let size_pk: usize = pk.size(); // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1); + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); { - let (mut u, _) = scratch_1.tmp_scalar(module, 1); + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); match pk.dist { SecretDistribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"