diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs
index 7574bf5..5749208 100644
--- a/rlwe/src/elem.rs
+++ b/rlwe/src/elem.rs
@@ -1,4 +1,4 @@
-use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc};
+use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc};
 
 pub struct Ciphertext<T> {
     data: T,
@@ -48,7 +48,12 @@ impl<T> Plaintext<T> {
     }
 }
 
-pub(crate) type CipherVecZnx<D> = Ciphertext<VecZnx<D>>;
+pub(crate) type CtVecZnx<D> = Ciphertext<VecZnx<D>>;
+pub(crate) type CtVecZnxDft<D, B> = Ciphertext<VecZnxDft<D, B>>;
+pub(crate) type CtMatZnxDft<D, B> = Ciphertext<MatZnxDft<D, B>>;
+pub(crate) type PtVecZnx<D> = Plaintext<VecZnx<D>>;
+pub(crate) type PtVecZnxDft<D, B> = Plaintext<VecZnxDft<D, B>>;
+pub(crate) type PtMatZnxDft<D, B> = Plaintext<MatZnxDft<D, B>>;
 
 impl Ciphertext<VecZnx<Vec<u8>>> {
     pub fn new(module: &Module<FFT64>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
@@ -70,6 +75,16 @@ impl Plaintext<VecZnx<Vec<u8>>> {
     }
 }
 
+impl<B: Backend> Ciphertext<VecZnxDft<Vec<u8>, B>> {
+    pub fn new(module: &Module<B>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
+        Self {
+            data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)),
+            log_base2k: log_base2k,
+            log_q: log_q,
+        }
+    }
+}
+
 impl<B: Backend> Ciphertext<MatZnxDft<Vec<u8>, B>> {
     pub fn new(module: &Module<B>, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self {
         Self {
diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs
index 3b291f9..3d62bfe 100644
--- a/rlwe/src/encryption.rs
+++ b/rlwe/src/encryption.rs
@@ -1,20 +1,21 @@
 use base2k::{
-    AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxBigAlloc,
-    VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
+    AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc,
+    VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef,
+    VecZnxToMut, VecZnxToRef, ZnxInfos,
 };
 
 use sampling::source::Source;
 
 use crate::{
-    elem::{CipherVecZnx, Plaintext},
+    elem::{CtVecZnx, CtVecZnxDft, PtVecZnx},
     keys::SecretKey,
 };
 
-pub trait EncryptSk<D> {
-    fn encrypt<P, S>(
+pub trait EncryptSk<D, P> {
+    fn encrypt<S>(
         module: &Module<FFT64>,
         res: &mut D,
-        pt: Option<&Plaintext<P>>,
+        pt: Option<&P>,
         sk: &SecretKey<S>,
         source_xa: &mut Source,
         source_xe: &mut Source,
@@ -22,7 +23,6 @@ pub trait EncryptSk<D> {
         sigma: f64,
         bound: f64,
     ) where
-        P: VecZnxToRef,
         S: ScalarZnxDftToRef;
 
     fn encrypt_tmp_bytes(module: &Module<FFT64>, size: usize) -> usize {
     }
 }
 
-impl<C> EncryptSk<CipherVecZnx<C>> for CipherVecZnx<C>
+impl<C, P> EncryptSk<CtVecZnx<C>, PtVecZnx<P>> for CtVecZnx<C>
 where
     VecZnx<C>: VecZnxToMut + VecZnxToRef,
+    VecZnx<P>: VecZnxToRef,
 {
-    fn encrypt<P, S>(
+    fn encrypt<S>(
         module: &Module<FFT64>,
-        ct: &mut CipherVecZnx<C>,
-        pt: Option<&Plaintext<P>>,
+        ct: &mut CtVecZnx<C>,
+        pt: Option<&PtVecZnx<P>>,
         sk: &SecretKey<S>,
         source_xa: &mut Source,
         source_xe: &mut Source,
@@ -45,7 +46,6 @@ where
         sigma: f64,
         bound: f64,
     ) where
-        P: VecZnxToRef,
        S: ScalarZnxDftToRef,
     {
         let log_base2k: usize = ct.log_base2k();
@@ -60,9 +60,10 @@
         {
             let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
 
+            module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1);
             // c0_dft = DFT(a) * DFT(s)
-            module.svp_apply(&mut c0_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1);
+            module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0);
 
             // c0_big = IDFT(c0_dft)
             module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
@@ -79,3 +80,73 @@ where
         module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1);
     }
 }
+
+pub trait EncryptZeroSk<D> {
+    fn encrypt_zero<S>(
+        module: &Module<FFT64>,
+        res: &mut D,
+        sk: &SecretKey<S>,
+        source_xa: &mut Source,
+        source_xe: &mut Source,
+        scratch: &mut Scratch,
+        sigma: f64,
+        bound: f64,
+    ) where
+        S: ScalarZnxDftToRef;
+
+    fn encrypt_zero_tmp_bytes(module: &Module<FFT64>, size: usize) -> usize {
+        (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
+            + module.bytes_of_vec_znx_big(1, size)
+            + module.bytes_of_vec_znx(1, size)
+            + module.vec_znx_big_normalize_tmp_bytes()
+    }
+}
+
+impl<C> EncryptZeroSk<CtVecZnxDft<C, FFT64>> for CtVecZnxDft<C, FFT64>
+where
+    VecZnxDft<C, FFT64>: VecZnxDftToMut + VecZnxDftToRef,
+{
+    fn encrypt_zero<S>(
+        module: &Module<FFT64>,
+        ct: &mut CtVecZnxDft<C, FFT64>,
+        sk: &SecretKey<S>,
+        source_xa: &mut Source,
+        source_xe: &mut Source,
+        scratch: &mut Scratch,
+        sigma: f64,
+        bound: f64,
+    ) where
+        S: ScalarZnxDftToRef,
+    {
+        let log_base2k: usize = ct.log_base2k();
+        let log_q: usize = ct.log_q();
+        let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.data_mut().to_mut();
+        let size: usize = ct_mut.size();
+
+        // ct[1] = DFT(a)
+        {
+            let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
+            tmp_znx.fill_uniform(log_base2k, 1, size, source_xa);
+            module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0);
+        }
+
+        let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
+
+        {
+            let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
+            // c0_dft = DFT(a) * DFT(s)
+            module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1);
+            // c0_big = IDFT(c0_dft)
+            module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
+        }
+
+        // c0_big += e
+        c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound);
+
+        // c0 = norm(c0_big = -as + e)
+        let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
+        module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
+        // ct[0] = DFT(-as + e)
+        module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0);
+    }
+}
diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs
index 50f1221..d84abc0 100644
--- a/rlwe/src/keys.rs
+++ b/rlwe/src/keys.rs
@@ -15,7 +15,7 @@ impl<T> SecretKey<T> {
         &self.data
     }
 
-    pub fn data_mut(&self) -> &mut T {
+    pub fn data_mut(&mut self) -> &mut T {
         &mut self.data
     }
 }
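
// --------------------------------------------------------------------------
// Usage sketch (reviewer note, not part of the patch). A minimal example of
// how the new `EncryptZeroSk` path could be driven from inside the `rlwe`
// crate, assuming the signatures reconstructed above (`CtVecZnxDft<D, B>`,
// `CtVecZnxDft::new`, `encrypt_zero`). The `Module`, `SecretKey`, `Scratch`
// and the two `Source` samplers are taken as arguments because their
// construction is not part of this diff; `LOG_BASE2K`, `LOG_Q`, `SIGMA` and
// `BOUND` are illustrative placeholder values, not values mandated by the
// crate. The caller presumably provisions `scratch` with at least
// `encrypt_zero_tmp_bytes(module, size)` bytes, using the helper added above.

use base2k::{FFT64, Module, ScalarZnxDftToRef, Scratch};
use sampling::source::Source;

use crate::{elem::CtVecZnxDft, encryption::EncryptZeroSk, keys::SecretKey};

const LOG_BASE2K: usize = 17; // assumed limb size in bits
const LOG_Q: usize = 54; // assumed ciphertext modulus size in bits
const SIGMA: f64 = 3.2; // assumed noise standard deviation
const BOUND: f64 = 19.0; // assumed truncation bound on the noise (~6 * SIGMA)

/// Samples a fresh DFT-domain encryption of zero:
/// column 0 = DFT(-a*s + e), column 1 = DFT(a).
pub fn sample_zero_ciphertext<S: ScalarZnxDftToRef>(
    module: &Module<FFT64>,
    sk: &SecretKey<S>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    scratch: &mut Scratch,
) -> CtVecZnxDft<Vec<u8>, FFT64> {
    // Two columns: the body (-a*s + e) and the uniform mask a.
    let mut ct = CtVecZnxDft::<Vec<u8>, FFT64>::new(module, LOG_BASE2K, LOG_Q, 2);
    CtVecZnxDft::<Vec<u8>, FFT64>::encrypt_zero(
        module, &mut ct, sk, source_xa, source_xe, scratch, SIGMA, BOUND,
    );
    ct
}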