diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 2875b97..d8c1bdd 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::fmt; use std::marker::PhantomData; diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index e86fd08..c943de1 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,7 +1,6 @@ -use base2k::{ - Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; +use base2k::ZnxInfos; + +use crate::utils::derive_size; pub trait Infos { type Inner: ZnxInfos; @@ -46,257 +45,3 @@ pub trait Infos { /// Returns the bit precision of the ciphertext. fn log_k(&self) -> usize; } - -pub struct RLWECt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWECt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWECt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWECt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWEPt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub struct GRLWECt { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GRLWECt, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GRLWECt { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for GRLWECt -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for GRLWECt -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub struct RGSWCt { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RGSWCt, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RGSWCt { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for RGSWCt -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for RGSWCt -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { - (log_k + log_base2k - 1) / log_base2k -} diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs new file mode 100644 index 0000000..b269cb3 --- /dev/null +++ b/rlwe/src/elem_grlwe.rs @@ -0,0 +1,53 @@ +use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module}; + +use crate::{elem::Infos, utils::derive_size}; + +pub struct GRLWECt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GRLWECt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GRLWECt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs new file mode 100644 index 0000000..1a1ea24 --- /dev/null +++ b/rlwe/src/elem_rgsw.rs @@ -0,0 +1,140 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + elem_grlwe::GRLWECt, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, + utils::derive_size, +}; + +pub struct RGSWCt { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RGSWCt, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RGSWCt { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl MatZnxDftToMut for RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GRLWECt, FFT64> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_bytes(module, size) + + module.bytes_of_vec_znx(2, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(2, size) + } + + pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + RLWECt::encrypt_pk_scratch_bytes(module, pk_size) + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + RLWECtDft::decrypt_scratch_bytes(module, size) + } +} + +pub fn encrypt_grlwe_sk( + module: &Module, + ct: &mut GRLWECt, + pt: &ScalarZnx

, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + MatZnxDft: MatZnxDftToMut, + ScalarZnx

: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let rows: usize = ct.rows(); + let size: usize = ct.size(); + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); + let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); + + let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt { + data: tmp_znx_pt, + log_base2k: ct.log_base2k(), + log_k: ct.log_k(), + }; + + let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt { + data: tmp_znx_ct, + log_base2k: ct.log_base2k(), + log_k: ct.log_k(), + }; + + (0..rows).for_each(|row_i| { + tmp_pt + .data + .at_mut(0, row_i) + .copy_from_slice(&pt.to_ref().raw()); + + tmp_ct.encrypt_sk( + module, + Some(&tmp_pt), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch_3, + ); + + tmp_pt.data.at_mut(0, row_i).fill(0); + + module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0); + module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1); + + module.vmp_prepare_row(ct, row_i, 0, &tmp_dft); + }); +} diff --git a/rlwe/src/encryption.rs b/rlwe/src/elem_rlwe.rs similarity index 76% rename from rlwe/src/encryption.rs rename to rlwe/src/elem_rlwe.rs index 0bdae33..8a7d444 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/elem_rlwe.rs @@ -1,20 +1,180 @@ -use std::cmp::min; - use base2k::{ AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, }; - use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, keys::{PublicKey, SecretDistribution, SecretKeyDft}, + utils::derive_size, }; -pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +pub struct RLWECt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWECt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWECt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for RLWEPt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for RLWECtDft +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl RLWECt> { + pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } + + pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } } pub fn encrypt_rlwe_sk( @@ -94,11 +254,7 @@ pub fn decrypt_rlwe( module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); pt.log_base2k = ct.log_base2k(); - pt.log_k = min(pt.log_k(), ct.log_k()); -} - -pub fn decrypt_rlwe_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + pt.log_k = pt.log_k().min(ct.log_k()); } impl RLWECt { @@ -207,11 +363,20 @@ pub(crate) fn encrypt_zero_rlwe_dft_sk( module.vec_znx_dft(ct, 0, &tmp_znx, 0); } -pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() +impl RLWECtDft, FFT64> { + pub fn encrypt_zero_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } } pub fn decrypt_rlwe_dft( @@ -246,14 +411,7 @@ pub fn decrypt_rlwe_dft( module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); pt.log_base2k = ct.log_base2k(); - pt.log_k = min(pt.log_k(), ct.log_k()); -} - -pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) + pt.log_k = pt.log_k().min(ct.log_k()); } impl RLWECtDft { @@ -290,12 +448,6 @@ impl RLWECtDft { } } -pub fn encrypt_rlwe_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() -} - pub(crate) fn encrypt_rlwe_pk( module: &Module, ct: &mut RLWECt, @@ -369,13 +521,10 @@ mod tests { use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, - encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes}, + elem_rlwe::{Infos, RLWECt, RLWECtDft, RLWEPt}, keys::{PublicKey, SecretKey, SecretKeyDft}, }; - use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; - #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); @@ -393,8 +542,9 @@ mod tests { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = - ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECt::decrypt_scratch_bytes(&module, ct.size()), + ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -469,9 +619,8 @@ mod tests { let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); let mut scratch: ScratchOwned = ScratchOwned::new( - encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size()) - | decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size()) - | encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()), + RLWECtDft::decrypt_scratch_bytes(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_bytes(&module, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -523,9 +672,9 @@ mod tests { ); let mut scratch: ScratchOwned = ScratchOwned::new( - encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) - | decrypt_rlwe_scratch_bytes(&module, ct.size()) - | encrypt_rlwe_pk_scratch_bytes(&module, pk.size()), + RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) + | RLWECt::decrypt_scratch_bytes(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_bytes(&module, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 89c33e3..2f7b2c7 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -4,10 +4,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{ - elem::{Infos, RLWECtDft}, - encryption::encrypt_zero_rlwe_dft_scratch_bytes, -}; +use crate::{elem::Infos, elem_rlwe::RLWECtDft}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -182,7 +179,10 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size())); + let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_bytes( + module, + self.size(), + )); self.data.encrypt_zero_sk( module, sk_dft, diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 023acb5..9eea116 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -1,3 +1,6 @@ pub mod elem; -pub mod encryption; +pub mod elem_grlwe; +pub mod elem_rgsw; +pub mod elem_rlwe; pub mod keys; +mod utils; diff --git a/rlwe/src/utils.rs b/rlwe/src/utils.rs new file mode 100644 index 0000000..0bb0b45 --- /dev/null +++ b/rlwe/src/utils.rs @@ -0,0 +1,3 @@ +pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { + (log_k + log_base2k - 1) / log_base2k +}