From 26c2bcbc05eeacc3b070c5dea3d3be7c5c277c60 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 24 Feb 2025 08:31:02 +0100 Subject: [PATCH] Fixed gadget product & added noise estimations --- Cargo.lock | 30 +++ Cargo.toml | 1 + base2k/Cargo.toml | 1 + base2k/src/encoding.rs | 38 ++++ base2k/src/lib.rs | 3 + base2k/src/stats.rs | 28 +++ base2k/src/vec_znx.rs | 3 + base2k/src/vmp.rs | 24 +-- rlwe/Cargo.toml | 1 + rlwe/benches/gadget_product.rs | 20 +- rlwe/examples/encryption.rs | 5 +- rlwe/examples/gadget_product.rs | 140 ------------ rlwe/src/ciphertext.rs | 123 ++++++----- rlwe/src/decryptor.rs | 24 +-- rlwe/src/elem.rs | 45 +++- rlwe/src/encryptor.rs | 43 ++-- rlwe/src/evaluator.rs | 171 --------------- rlwe/src/gadget_product.rs | 371 ++++++++++++++++++++++++++++++++ rlwe/src/key_generator.rs | 8 +- rlwe/src/keys.rs | 2 +- rlwe/src/lib.rs | 3 +- rlwe/src/parameters.rs | 4 +- rlwe/src/plaintext.rs | 92 ++++---- rlwe/src/rgsw_product.rs | 55 +++++ 24 files changed, 762 insertions(+), 473 deletions(-) create mode 100644 base2k/src/stats.rs delete mode 100644 rlwe/examples/gadget_product.rs delete mode 100644 rlwe/src/evaluator.rs create mode 100644 rlwe/src/gadget_product.rs create mode 100644 rlwe/src/rgsw_product.rs diff --git a/Cargo.lock b/Cargo.lock index 631a1f5..c32b92a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "az" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973" + [[package]] name = "base2k" version = "0.1.0" @@ -58,6 +64,7 @@ dependencies = [ "rand", "rand_core", "rand_distr", + "rug", "sampling", "utils", ] @@ -228,6 +235,16 @@ dependencies = [ "wasi", ] +[[package]] +name = "gmp-mpfr-sys" +version = "1.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0205cd82059bc63b63cf516d714352a30c44f2c74da9961dfda2617ae6b5918" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "half" version = "2.4.1" @@ -639,6 +656,7 @@ dependencies = [ "base2k", "criterion", "rand_distr", + "rug", "sampling", ] @@ -660,6 +678,18 @@ dependencies = [ "utils", ] +[[package]] +name = "rug" +version = "1.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3" +dependencies = [ + "az", + "gmp-mpfr-sys", + "libc", + "libm", +] + [[package]] name = "ryu" version = "1.0.18" diff --git a/Cargo.toml b/Cargo.toml index 266d242..4a1653f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = ["base2k", "rlwe", "rns", "sampling", "utils"] [workspace.dependencies] +rug = "1.27" rand = "0.8.4" rand_chacha = "0.3.1" rand_core = "0.6.4" diff --git a/base2k/Cargo.toml b/base2k/Cargo.toml index 5e829f2..2ebb8db 100644 --- a/base2k/Cargo.toml +++ b/base2k/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +rug = {workspace = true} criterion = {workspace = true} itertools = {workspace = true} rand = {workspace = true} diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 5bdaa72..0380c03 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,6 +1,7 @@ use crate::ffi::znx::znx_zero_i64_ref; use crate::{Infos, VecZnx, VecZnxApi}; use itertools::izip; +use rug::{Assign, Float}; use std::cmp::min; pub trait Encoding { @@ -23,6 +24,13 @@ pub trait Encoding { /// * `data`: data to decode from the receiver. fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]); + /// decode a vector of Float from the receiver. + /// + /// # Arguments + /// * `log_base2k`: base two logarithm decomposition of the receiver. + /// * `data`: data to decode from the receiver. + fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]); + /// encodes a single i64 on the receiver at the given index. /// /// # Arguments @@ -123,6 +131,36 @@ impl Encoding for VecZnx { }) } + fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) { + let cols: usize = self.cols(); + assert!( + data.len() >= self.n(), + "invalid data: data.len()={} < self.n()={}", + data.len(), + self.n() + ); + + let prec: u32 = (log_base2k * cols) as u32; + + // 2^{log_base2k} + let base = Float::with_val(prec, (1 << log_base2k) as f64); + + // y[i] = sum x[j][i] * 2^{-log_base2k*j} + (0..cols).for_each(|i| { + if i == 0 { + izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + y.assign(*x); + *y /= &base; + }); + } else { + izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + *y += Float::with_val(prec, *x); + *y /= &base; + }); + } + }); + } + fn encode_coeff_i64( &mut self, log_base2k: usize, diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index ee25f6f..c69243e 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -12,6 +12,7 @@ pub mod free; pub mod infos; pub mod module; pub mod sampling; +pub mod stats; pub mod svp; pub mod vec_znx; pub mod vec_znx_big; @@ -23,6 +24,8 @@ pub use free::*; pub use infos::*; pub use module::*; pub use sampling::*; +#[allow(unused_imports)] +pub use stats::*; pub use svp::*; pub use vec_znx::*; pub use vec_znx_big::*; diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs new file mode 100644 index 0000000..60568b3 --- /dev/null +++ b/base2k/src/stats.rs @@ -0,0 +1,28 @@ +use crate::{Infos, VecZnx, Encoding}; +use rug::float::Round; +use rug::ops::{AddAssignRound, SubAssignRound, DivAssignRound}; +use rug::Float; + +impl VecZnx { + pub fn std(&self, log_base2k: usize) -> f64 { + let prec: u32 = (self.cols() * log_base2k) as u32; + let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); + self.decode_vec_float(log_base2k, &mut data); + // std = sqrt(sum((xi - avg)^2) / n) + let mut avg: Float = Float::with_val(prec, 0); + data.iter().for_each(|x| { + avg.add_assign_round(x, Round::Nearest); + }); + avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); + data.iter_mut().for_each(|x| { + x.sub_assign_round(&avg, Round::Nearest); + }); + let mut std: Float = Float::with_val(prec, 0); + data.iter().for_each(|x| { + std += x*x + }); + std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); + std = std.sqrt(); + std.to_f64() + } +} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index eb885fb..01118be 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -42,6 +42,9 @@ pub trait VecZnxApi { /// Zeroes the backing array. fn zero(&mut self); + + /// Normalization: propagates carry and ensures each coefficients + /// falls into the range [-2^{K-1}, 2^{K-1}]. fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]); /// Right shifts the coefficients by k bits. diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 46d2a09..39be0ee 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -454,18 +454,18 @@ impl VmpPMatOps for Module { fn vmp_apply_dft_tmp_bytes( &self, - c_cols: usize, + res_cols: usize, a_cols: usize, - rows: usize, - cols: usize, + gct_rows: usize, + gct_cols: usize, ) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.0, - c_cols as u64, + res_cols as u64, a_cols as u64, - rows as u64, - cols as u64, + gct_rows as u64, + gct_cols as u64, ) as usize } } @@ -495,18 +495,18 @@ impl VmpPMatOps for Module { fn vmp_apply_dft_to_dft_tmp_bytes( &self, - c_cols: usize, + res_cols: usize, a_cols: usize, - rows: usize, - cols: usize, + gct_rows: usize, + gct_cols: usize, ) -> usize { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.0, - c_cols as u64, + res_cols as u64, a_cols as u64, - rows as u64, - cols as u64, + gct_rows as u64, + gct_cols as u64, ) as usize } } diff --git a/rlwe/Cargo.toml b/rlwe/Cargo.toml index cf4ae7a..21c5ddd 100644 --- a/rlwe/Cargo.toml +++ b/rlwe/Cargo.toml @@ -6,6 +6,7 @@ version = "0.1.0" edition = "2024" [dependencies] +rug = {workspace = true} criterion = {workspace = true} base2k = {path="../base2k"} sampling = {path="../sampling"} diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index b91b6c0..5cedf5c 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,10 +1,11 @@ +/* use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::Elem, - encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, - evaluator::{gadget_product_inplace, gadget_product_tmp_bytes}, + elem::{Elem, ElemCommon}, + encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, + gadget_product::{gadget_product_core, gadget_product_tmp_bytes}, key_generator::gen_switching_key_thread_safe_tmp_bytes, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, @@ -18,7 +19,9 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { gadget_ct: &'a Ciphertext, tmp_bytes: &'a mut [u8], ) -> Box { - Box::new(move || gadget_product_inplace::(module, elem, gadget_ct, tmp_bytes)) + Box::new(move || { + gadget_product_inplace::(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes) + }) } let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = @@ -43,7 +46,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { | gen_switching_key_thread_safe_tmp_bytes( params.module(), params.log_base2k(), - params.limbs_q(), + params.cols_q(), params.log_q(), ) | gadget_product_tmp_bytes( @@ -51,13 +54,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { params.log_base2k(), params.log_q(), params.log_q(), - params.limbs_q(), + params.cols_q(), params.log_qp(), ) | encrypt_grlwe_sk_tmp_bytes( params.module(), params.log_base2k(), - params.limbs_qp(), + params.cols_qp(), params.log_qp(), ), 64, @@ -82,7 +85,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let mut gadget_ct: Ciphertext = new_gadget_ciphertext( params.module(), params.log_base2k(), - params.limbs_q(), + params.cols_q(), params.log_qp(), ); @@ -123,3 +126,4 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { criterion_group!(benches, bench_gadget_product_inplace); criterion_main!(benches); +*/ diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index 25689f6..4a05523 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -1,6 +1,7 @@ use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi}; use rlwe::{ ciphertext::Ciphertext, + elem::ElemCommon, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext, @@ -22,7 +23,7 @@ fn main() { let mut tmp_bytes: Vec = vec![ 0u8; - params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) + params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) ]; @@ -64,7 +65,7 @@ fn main() { &mut tmp_bytes, ); - params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); + params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); pt.0.value[0].print(pt.cols(), 16); diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs deleted file mode 100644 index f9e205f..0000000 --- a/rlwe/examples/gadget_product.rs +++ /dev/null @@ -1,140 +0,0 @@ -use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi, VmpPMat}; -use rlwe::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - decryptor::decrypt_rlwe_thread_safe, - encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, - evaluator::{gadget_product_inplace, gadget_product_tmp_bytes}, - key_generator::gen_switching_key_thread_safe_tmp_bytes, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, -}; -use sampling::source::Source; - -fn main() { - let params_lit: ParametersLiteral = ParametersLiteral { - log_n: 4, - log_q: 68, - log_p: 17, - log_base2k: 17, - log_scale: 20, - xe: 3.2, - xs: 8, - }; - - let params: Parameters = Parameters::new::(¶ms_lit); - - let mut tmp_bytes: Vec = vec![ - 0u8; - params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) - | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) - | gen_switching_key_thread_safe_tmp_bytes( - params.module(), - params.log_base2k(), - params.limbs_q(), - params.log_q() - ) - | gadget_product_tmp_bytes( - params.module(), - params.log_base2k(), - params.log_q(), - params.log_q(), - params.limbs_q(), - params.log_qp() - ) - | encrypt_grlwe_sk_tmp_bytes( - params.module(), - params.log_base2k(), - params.limbs_qp(), - params.log_qp() - ) - ]; - - let mut source: Source = Source::new([3; 32]); - - let mut sk0: SecretKey = SecretKey::new(params.module()); - let mut sk1: SecretKey = SecretKey::new(params.module()); - - sk0.fill_ternary_hw(params.xs(), &mut source); - sk1.fill_ternary_hw(params.xs(), &mut source); - - let mut want = vec![i64::default(); params.n()]; - - want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - - let log_base2k = params.log_base2k(); - - let log_k: usize = params.log_q() - 2 * log_base2k; - - let mut source_xe: Source = Source::new([4; 32]); - let mut source_xa: Source = Source::new([5; 32]); - - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( - params.module(), - log_base2k, - params.limbs_q(), - params.log_qp(), - ); - - encrypt_grlwe_sk_thread_safe( - params.module(), - &mut gadget_ct, - &sk0.0, - &sk1_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - let mut pt: Plaintext = - Plaintext::::new(params.module(), params.log_base2k(), params.log_q()); - - let mut want = vec![i64::default(); params.n()]; - want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); - pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - params.encrypt_rlwe_sk_thread_safe( - &mut ct, - Some(&pt), - &sk0_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - gadget_product_inplace::(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes); - - println!("ct.limbs()={}", ct.cols()); - println!("gadget_ct.rows()={}", gadget_ct.rows()); - println!("gadget_ct.cols()={}", gadget_ct.cols()); - println!("res.limbs()={}", ct.cols()); - println!(); - - decrypt_rlwe_thread_safe( - params.module(), - &mut pt.0, - &ct.0, - &sk1_svp_ppol, - &mut tmp_bytes, - ); - - pt.0.value[0].print(pt.cols(), 16); - - let mut have: Vec = vec![i64::default(); params.n()]; - - println!("pt: {}", log_k); - pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have); - - println!("want: {:?}", &want[..16]); - println!("have: {:?}", &have[..16]); -} diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index bf9a231..3274ae8 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,70 +1,77 @@ -use crate::elem::{Elem, ElemVecZnx, VecZnxCommon}; +use crate::elem::{Elem, ElemCommon}; use crate::parameters::Parameters; -use crate::plaintext::Plaintext; use base2k::{Infos, Module, VecZnx, VmpPMat}; pub struct Ciphertext(pub Elem); +impl Parameters { + pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext { + Ciphertext::new(self.module(), self.log_base2k(), log_q, 2) + } +} + +impl ElemCommon for Ciphertext +where + T: Infos, +{ + fn n(&self) -> usize { + self.elem().n() + } + + fn log_n(&self) -> usize { + self.elem().log_n() + } + + fn log_q(&self) -> usize { + self.elem().log_q() + } + + fn elem(&self) -> &Elem { + &self.0 + } + + fn elem_mut(&mut self) -> &mut Elem { + &mut self.0 + } + + fn size(&self) -> usize { + self.elem().size() + } + + fn rows(&self) -> usize { + self.elem().rows() + } + + fn cols(&self) -> usize { + self.elem().cols() + } + + fn at(&self, i: usize) -> &T { + self.elem().at(i) + } + + fn at_mut(&mut self, i: usize) -> &mut T { + self.elem_mut().at_mut(i) + } + + fn log_base2k(&self) -> usize { + self.elem().log_base2k() + } + + fn log_scale(&self) -> usize { + self.elem().log_scale() + } +} + impl Ciphertext { pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self { Self(Elem::::new(module, log_base2k, log_q, rows)) } } -impl Ciphertext -where - T: VecZnxCommon, -{ - pub fn zero(&mut self) { - self.0.zero() - } - - pub fn as_plaintext(&self) -> Plaintext { - unsafe { Plaintext::(std::ptr::read(&self.0)) } - } -} - -impl Ciphertext -where - T: Infos, -{ - pub fn n(&self) -> usize { - self.0.n() - } - - pub fn log_q(&self) -> usize { - self.0.log_q - } - - pub fn rows(&self) -> usize { - self.0.rows() - } - - pub fn cols(&self) -> usize { - self.0.cols() - } - - pub fn at(&self, i: usize) -> &T { - self.0.at(i) - } - - pub fn at_mut(&mut self, i: usize) -> &mut T { - self.0.at_mut(i) - } - - pub fn log_base2k(&self) -> usize { - self.0.log_base2k - } - - pub fn log_scale(&self) -> usize { - self.0.log_scale - } -} - -impl Parameters { - pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext { - Ciphertext::new(self.module(), self.log_base2k(), log_q, 2) - } +pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext { + let rows: usize = 2; + Ciphertext::::new(module, log_base2k, log_q, rows) } pub fn new_gadget_ciphertext( @@ -74,7 +81,7 @@ pub fn new_gadget_ciphertext( log_q: usize, ) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 1, rows, 2 * cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); elem.log_q = log_q; Ciphertext(elem) } @@ -86,7 +93,7 @@ pub fn new_rgsw_ciphertext( log_q: usize, ) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, 2 * cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); elem.log_q = log_q; Ciphertext(elem) } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 4074315..63a892b 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -1,6 +1,6 @@ use crate::{ ciphertext::Ciphertext, - elem::{Elem, ElemVecZnx, VecZnxCommon}, + elem::{Elem, ElemCommon, VecZnxCommon}, keys::SecretKey, parameters::Parameters, plaintext::Plaintext, @@ -20,19 +20,19 @@ impl Decryptor { } } -pub fn decrypt_rlwe_thread_safe_tmp_byte(module: &Module, limbs: usize) -> usize { +pub fn decrypt_rlwe_tmp_byte(module: &Module, limbs: usize) -> usize { module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes() } impl Parameters { - pub fn decrypt_rlwe_thread_safe_tmp_byte(&self, log_q: usize) -> usize { - decrypt_rlwe_thread_safe_tmp_byte( + pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize { + decrypt_rlwe_tmp_byte( self.module(), (log_q + self.log_base2k() - 1) / self.log_base2k(), ) } - pub fn decrypt_rlwe_thread_safe( + pub fn decrypt_rlwe( &self, res: &mut Plaintext, ct: &Ciphertext, @@ -40,13 +40,13 @@ impl Parameters { tmp_bytes: &mut [u8], ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { - decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) + decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } } -pub fn decrypt_rlwe_thread_safe( +pub fn decrypt_rlwe( module: &Module, res: &mut Elem, a: &Elem, @@ -54,15 +54,15 @@ pub fn decrypt_rlwe_thread_safe( tmp_bytes: &mut [u8], ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { let cols: usize = a.cols(); assert!( - tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, cols), - "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}", + tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols), + "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}", tmp_bytes.len(), - decrypt_rlwe_thread_safe_tmp_byte(module, cols) + decrypt_rlwe_tmp_byte(module, cols) ); let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols); diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 5c1348d..88c5795 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -75,45 +75,68 @@ where } } -impl Elem { - pub fn n(&self) -> usize { +pub trait ElemCommon { + fn n(&self) -> usize; + fn log_n(&self) -> usize; + fn elem(&self) -> &Elem; + fn elem_mut(&mut self) -> &mut Elem; + fn size(&self) -> usize; + fn rows(&self) -> usize; + fn cols(&self) -> usize; + fn log_base2k(&self) -> usize; + fn log_q(&self) -> usize; + fn log_scale(&self) -> usize; + fn at(&self, i: usize) -> &T; + fn at_mut(&mut self, i: usize) -> &mut T; +} + +impl ElemCommon for Elem { + fn n(&self) -> usize { self.value[0].n() } - pub fn log_n(&self) -> usize { + fn log_n(&self) -> usize { self.value[0].log_n() } - pub fn size(&self) -> usize { + fn elem(&self) -> &Elem { + self + } + + fn elem_mut(&mut self) -> &mut Elem { + self + } + + fn size(&self) -> usize { self.value.len() } - pub fn rows(&self) -> usize { + fn rows(&self) -> usize { self.value[0].rows() } - pub fn cols(&self) -> usize { + fn cols(&self) -> usize { self.value[0].cols() } - pub fn log_base2k(&self) -> usize { + fn log_base2k(&self) -> usize { self.log_base2k } - pub fn log_q(&self) -> usize { + fn log_q(&self) -> usize { self.log_q } - pub fn log_scale(&self) -> usize { + fn log_scale(&self) -> usize { self.log_scale } - pub fn at(&self, i: usize) -> &T { + fn at(&self, i: usize) -> &T { assert!(i < self.size()); &self.value[i] } - pub fn at_mut(&mut self, i: usize) -> &mut T { + fn at_mut(&mut self, i: usize) -> &mut T { assert!(i < self.size()); &mut self.value[i] } diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 8abfcb0..7156d52 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -1,5 +1,5 @@ use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemVecZnx, VecZnxCommon}; +use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; use crate::keys::SecretKey; use crate::parameters::Parameters; use crate::plaintext::Plaintext; @@ -32,7 +32,7 @@ impl EncryptorSk { initialized, source_xa: Source::new(new_seed()), source_xe: Source::new(new_seed()), - tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.limbs_qp())], + tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())], } } @@ -56,7 +56,7 @@ impl EncryptorSk { pt: Option<&Plaintext>, ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { assert!( self.initialized == true, @@ -82,7 +82,7 @@ impl EncryptorSk { tmp_bytes: &mut [u8], ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { assert!( self.initialized == true, @@ -107,7 +107,7 @@ impl Parameters { tmp_bytes: &mut [u8], ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { encrypt_rlwe_sk_thread_safe( self.module(), @@ -138,7 +138,7 @@ pub fn encrypt_rlwe_sk_thread_safe( tmp_bytes: &mut [u8], ) where T: VecZnxCommon, - Elem: ElemVecZnx, + Elem: ElemCommon, { let cols: usize = ct.cols(); let log_base2k: usize = ct.log_base2k(); @@ -197,6 +197,12 @@ pub fn encrypt_rlwe_sk_thread_safe( ); } +impl Parameters { + pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { + encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) + } +} + pub fn encrypt_grlwe_sk_tmp_bytes( module: &Module, log_base2k: usize, @@ -207,10 +213,10 @@ pub fn encrypt_grlwe_sk_tmp_bytes( Elem::::bytes_of(module, log_base2k, log_q, 2) + Plaintext::::bytes_of(module, log_base2k, log_q) + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - + module.vmp_prepare_tmp_bytes(rows, 2 * cols) + + module.vmp_prepare_tmp_bytes(rows, cols) } -pub fn encrypt_grlwe_sk_thread_safe( +pub fn encrypt_grlwe_sk( module: &Module, ct: &mut Ciphertext, m: &Scalar, @@ -249,7 +255,7 @@ pub fn encrypt_grlwe_sk_thread_safe( (0..rows).for_each(|row_i| { // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) - tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0); + tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.0); // Encrypts RLWE(m * 2^{-log_base2k*i}) encrypt_rlwe_sk_thread_safe( @@ -263,19 +269,28 @@ pub fn encrypt_grlwe_sk_thread_safe( tmp_bytes_enc_sk, ); + //tmp_pt.at(0).print(tmp_pt.cols(), 16); + //println!(); + // Zeroes the ith-row of tmp_pt - tmp_pt.0.value[0].at_mut(row_i).fill(0); + tmp_pt.at_mut(0).at_mut(row_i).fill(0); //println!("row:{}/{}", row_i, rows); - //tmp_elem.at(0).print(tmp_elem.limbs(), tmp_elem.n()); - //tmp_elem.at(1).print(tmp_elem.limbs(), tmp_elem.n()); + //tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n()); + //tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n()); //println!(); //println!(">>>"); // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a] module.vmp_prepare_row( - &mut ct.0.value[0], - cast_mut::(tmp_bytes_elem), + &mut ct.at_mut(0), + tmp_elem.at(0).raw(), + row_i, + tmp_bytes_vmp_prepare_row, + ); + module.vmp_prepare_row( + &mut ct.at_mut(1), + tmp_elem.at(1).raw(), row_i, tmp_bytes_vmp_prepare_row, ); diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs deleted file mode 100644 index 30dffc4..0000000 --- a/rlwe/src/evaluator.rs +++ /dev/null @@ -1,171 +0,0 @@ -use crate::{ - ciphertext::Ciphertext, - elem::{Elem, ElemVecZnx, VecZnxCommon}, -}; -use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; -use std::cmp::min; - -pub fn gadget_product_tmp_bytes( - module: &Module, - log_base2k: usize, - out_log_q: usize, - in_log_q: usize, - gct_rows: usize, - gct_log_q: usize, -) -> usize { - let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k; - let in_limbs: usize = (in_log_q + log_base2k - 1) / log_base2k; - let out_limbs: usize = (out_log_q + log_base2k - 1) / log_base2k; - module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, gct_rows, gct_cols) - + 2 * module.bytes_of_vec_znx_dft(gct_cols) -} - -pub fn gadget_product_inplace( - module: &Module, - res: &mut Elem, - b: &Ciphertext, - tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ - unsafe { - let a_ptr: *const T = res.at(1) as *const T; - gadget_product::(module, res, &*a_ptr, b, tmp_bytes); - } -} - -/// Evaluates the gadget product res <- a x b. -/// -/// # Arguments -/// -/// * `module`: backend support for operations mod (X^N + 1). -/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols limbs. -/// * `a`: a [VecZnx] of a_ncols limbs. -/// * `b`: a [GadgetCiphertext] as a vector of (-Bs + m * 2^{-k} + E, B) -/// containing b_nrows [VecZnx], each of b_ncols limbs. -/// -/// # Computation -/// -/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) -/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs. -pub fn gadget_product( - module: &Module, - res: &mut Elem, - a: &T, - b: &Ciphertext, - tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ - let log_base2k: usize = b.log_base2k(); - let rows: usize = min(b.rows(), a.cols()); - let cols: usize = b.cols(); - - let bytes_vmp_apply_dft: usize = - module.vmp_apply_dft_to_dft_tmp_bytes(cols, a.cols(), rows, cols); - let bytes_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols); - - let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft); - let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft); - - let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft); - let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - - // Alias c0 and c1 part of res_big - let (tmp_bytes_res_dft_c0, tmp_bytes_res_dft_c1) = - tmp_bytes_res_dft.split_at_mut(bytes_vec_znx_dft >> 1); - let res_big_c0: VecZnxBig = module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c0); - let mut res_big_c1: VecZnxBig = - module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1); - - // tmp_a_dft <- DFT(a) - // (n x cols) <- (n x limbs=rows) x (rows x cols) - // res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1)) - gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft); - - // res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); - - // res_big <- res[0] + res_big[a*G0] - module.vec_znx_big_add_small_inplace(&mut res_big, res.at(0)); - module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_c0, tmp_bytes_c1_dft); - - if OVERWRITE { - // res[1] = normalize(res_big[a*G1]) - module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft); - } else { - // res[1] = normalize(res_big[a*G1] + res[1]) - module.vec_znx_big_add_small_inplace(&mut res_big_c1, res.at(1)); - module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft); - } -} - -pub fn gadget_product_core( - module: &Module, - res_dft: &mut VecZnxDft, - a: &T, - b: &VmpPMat, - tmp_bytes_vmp_apply_dft: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ - // res_dft <- DFT(a) - module.vec_znx_dft(res_dft, a, a.cols()); - - // (n x cols) <- (n x limbs=rows) x (rows x cols) - // res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1)) - module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft); -} - -pub fn rgsw_product( - module: &Module, - res: &mut Elem, - a: &Ciphertext, - b: &Ciphertext, - tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ - let log_base2k: usize = b.log_base2k(); - let rows: usize = min(b.rows(), a.cols()); - let cols: usize = b.cols(); - let in_cols = a.cols(); - let out_cols: usize = a.cols(); - - let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols); - let bytes_of_vmp_apply_dft_to_dft = - module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols); - - let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) = - tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft); - - let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft); - let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); - let mut tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft); - let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft); - let mut r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft); - - // c0_dft <- DFT(a[0]) - module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols); - - // r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols] - module.vmp_apply_dft_to_dft( - &mut r1_dft, - &c1_dft, - &b.0.value[0], - bytes_of_vmp_apply_dft_to_dft, - ); - - // c1_dft <- DFT(a[1]) - module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols); -} diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs new file mode 100644 index 0000000..4027575 --- /dev/null +++ b/rlwe/src/gadget_product.rs @@ -0,0 +1,371 @@ +use crate::{ + ciphertext::Ciphertext, + elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}, + parameters::Parameters, +}; +use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; +use std::cmp::min; + +pub fn gadget_product_tmp_bytes( + module: &Module, + log_base2k: usize, + res_log_q: usize, + in_log_q: usize, + gct_rows: usize, + gct_log_q: usize, +) -> usize { + let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k; + let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k; + let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k; + module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols) +} + +impl Parameters { + pub fn gadget_product_tmp_bytes( + &self, + res_log_q: usize, + in_log_q: usize, + gct_rows: usize, + gct_log_q: usize, + ) -> usize { + gadget_product_tmp_bytes( + self.module(), + self.log_base2k(), + res_log_q, + in_log_q, + gct_rows, + gct_log_q, + ) + } +} + +/// Evaluates the gadget product res <- a x b. +/// +/// # Arguments +/// +/// * `module`: backend support for operations mod (X^N + 1). +/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols cols. +/// * `a`: a [VecZnx] of a_ncols cols. +/// * `b`: a [Ciphertext] as a vector of (-Bs + m * 2^{-k} + E, B) +/// containing b_nrows [VecZnx], each of b_ncols cols. +/// +/// # Computation +/// +/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) +/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols. +pub fn gadget_product_core( + module: &Module, + res_dft_0: &mut VecZnxDft, + res_dft_1: &mut VecZnxDft, + a: &T, + a_cols: usize, + b: &Ciphertext, + b_cols: usize, + tmp_bytes: &mut [u8], +) where + T: VecZnxCommon, + Elem: ElemVecZnx, +{ + assert!(b_cols <= b.cols()); + module.vec_znx_dft(res_dft_1, a, a_cols); + module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); + module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes); +} + +/* +// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) +module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); +module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); + +// res_big <- res[0] + res_big[a*G0] +module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0)); +module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry); + +if OVERWRITE { + // res[1] = normalize(res_big[a*G1]) + module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry); +} else { + // res[1] = normalize(res_big[a*G1] + res[1]) + module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1)); + module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry); +} +*/ + +#[cfg(test)] +mod test { + use crate::{ + ciphertext::{Ciphertext, new_gadget_ciphertext}, + decryptor::decrypt_rlwe, + elem::{Elem, ElemCommon, ElemVecZnx}, + encryptor::encrypt_grlwe_sk, + gadget_product::gadget_product_core, + keys::SecretKey, + parameters::{Parameters, ParametersLiteral}, + plaintext::Plaintext, + }; + use base2k::{ + FFT64, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftOps, VecZnxOps, VmpPMat, + }; + use sampling::source::{Source, new_seed}; + + #[test] + fn test_gadget_product_core() { + let log_base2k: usize = 10; + let q_cols: usize = 7; + let p_cols: usize = 1; + + // Basic parameters with enough limbs to test edge cases + let params_lit: ParametersLiteral = ParametersLiteral { + log_n: 12, + log_q: q_cols * log_base2k, + log_p: p_cols * log_base2k, + log_base2k: log_base2k, + log_scale: 20, + xe: 3.2, + xs: 1 << 11, + }; + + let params: Parameters = Parameters::new::(¶ms_lit); + + // scratch space + let mut tmp_bytes: Vec = + vec![ + 0u8; + params.decrypt_rlwe_tmp_byte(params.log_qp()) + | params.encrypt_rlwe_sk_tmp_bytes(params.log_qp()) + | params.gadget_product_tmp_bytes( + params.log_qp(), + params.log_qp(), + params.cols_qp(), + params.log_qp() + ) + | params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp()) + ]; + + // Samplers for public and private randomness + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + let mut source_xs: Source = Source::new(new_seed()); + + // Two secret keys + let mut sk0: SecretKey = SecretKey::new(params.module()); + sk0.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); + + let mut sk1: SecretKey = SecretKey::new(params.module()); + sk1.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); + + // The gadget ciphertext + let mut gadget_ct: Ciphertext = new_gadget_ciphertext( + params.module(), + log_base2k, + params.cols_qp(), + params.log_qp(), + ); + + // gct = [-b*sk1 + g(sk0) + e, b] + encrypt_grlwe_sk( + params.module(), + &mut gadget_ct, + &sk0.0, + &sk1_svp_ppol, + &mut source_xa, + &mut source_xe, + params.xe(), + &mut tmp_bytes, + ); + + // Intermediate buffers + let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); + let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); + let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big(); + let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big(); + + // Input polynopmial, uniformly distributed + let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); + params + .module() + .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa); + + // res = g^-1(a) * gct + let mut elem_res: Elem = + Elem::::new(params.module(), log_base2k, params.log_qp(), 2); + + // Ideal output = a * s + let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols()); + let mut a_big: VecZnxBig = a_dft.as_vec_znx_big(); + let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols()); + + // a * sk0 + params + .module() + .svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols()); + params + .module() + .vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols()); + params.module().vec_znx_big_normalize( + params.log_base2k(), + &mut a_times_s, + &a_big, + &mut tmp_bytes, + ); + + // Plaintext for decrypted output of gadget product + let mut pt: Plaintext = + Plaintext::::new(params.module(), params.log_base2k(), params.log_qp()); + + // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. + + pt.elem_mut().zero(); + elem_res.zero(); + + let a_cols: usize = a.cols() - 1; + let b_cols: usize = gadget_ct.cols(); + + println!("a_cols: {} b_cols: {}", a_cols, b_cols); + + // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') + // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) + gadget_product_core::( + params.module(), + &mut res_dft_0, + &mut res_dft_1, + &a, + a_cols, + &gadget_ct, + b_cols, + &mut tmp_bytes, + ); + + // res_big_0 = IDFT(res_dft_0) + params + .module() + .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); + // res_big_1 = IDFT(res_dft_1); + params + .module() + .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); + + // res_big_0 = normalize(res_big_0) + params.module().vec_znx_big_normalize( + log_base2k, + elem_res.at_mut(0), + &res_big_0, + &mut tmp_bytes, + ); + + // res_big_1 = normalize(res_big_1) + params.module().vec_znx_big_normalize( + log_base2k, + elem_res.at_mut(1), + &res_big_1, + &mut tmp_bytes, + ); + + // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e + decrypt_rlwe( + params.module(), + pt.elem_mut(), + &elem_res, + &sk1_svp_ppol, + &mut tmp_bytes, + ); + + // a * sk0 + e - a*sk0 = e + params + .module() + .vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s); + pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); + + pt.at(0).print(pt.elem().cols(), 16); + + println!("noise_have: {}", pt.at(0).std(log_base2k).log2()); + + let var_a_err: f64; + + if a_cols < a.cols() { + var_a_err = 1f64 / 12f64; + } else { + var_a_err = 0f64; + } + + let a_logq: usize = a_cols * log_base2k; + let b_logq: usize = b_cols * log_base2k; + let var_msg: f64 = params.xs() as f64; + println!( + "noise_pred: {}", + params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq) + ); + } +} + +impl Parameters { + pub fn noise_grlwe_product( + &self, + var_msg: f64, + var_a_err: f64, + a_logq: usize, + b_logq: usize, + ) -> f64 { + let n: f64 = self.n() as f64; + let var_xs: f64 = self.xs() as f64; + + let var_gct_err_lhs: f64; + let var_gct_err_rhs: f64; + if b_logq < self.log_qp() { + let var_round: f64 = 1f64 / 12f64; + var_gct_err_lhs = var_round; + var_gct_err_rhs = var_round; + } else { + var_gct_err_lhs = self.xe() * self.xe(); + var_gct_err_rhs = 0f64; + } + + noise_grlwe_product( + n, + self.log_base2k(), + var_xs, + var_msg, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + a_logq, + b_logq, + ) + } +} + +pub fn noise_grlwe_product( + n: f64, + log_base2k: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + let b_cols: usize = (b_logq + log_base2k - 1) / log_base2k; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (log_base2k)) as f64; + let var_base: f64 = base * base / 12f64; + let var_round: f64 = 1f64 / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = + (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * a_scale * a_scale; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2() +} diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 6546317..7ca311a 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1,4 +1,4 @@ -use crate::encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}; +use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}; use crate::keys::{PublicKey, SecretKey, SwitchingKey}; use crate::parameters::Parameters; use base2k::{Module, SvpPPol}; @@ -40,7 +40,7 @@ impl KeyGenerator { } } -pub fn gen_switching_key_thread_safe_tmp_bytes( +pub fn gen_switching_key_tmp_bytes( module: &Module, log_base2k: usize, rows: usize, @@ -49,7 +49,7 @@ pub fn gen_switching_key_thread_safe_tmp_bytes( encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) } -pub fn gen_switching_key_thread_safe( +pub fn gen_switching_key( module: &Module, swk: &mut SwitchingKey, sk_in: &SecretKey, @@ -59,7 +59,7 @@ pub fn gen_switching_key_thread_safe( sigma: f64, tmp_bytes: &mut [u8], ) { - encrypt_grlwe_sk_thread_safe( + encrypt_grlwe_sk( module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes, ); } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index c1808dd..df78e11 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,5 +1,5 @@ use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; -use crate::elem::Elem; +use crate::elem::{Elem, ElemCommon}; use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes}; use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat}; use sampling::source::Source; diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index a559681..1243297 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -2,8 +2,9 @@ pub mod ciphertext; pub mod decryptor; pub mod elem; pub mod encryptor; -pub mod evaluator; +pub mod gadget_product; pub mod key_generator; pub mod keys; pub mod parameters; pub mod plaintext; +pub mod rgsw_product; diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs index 86946ec..8cc948f 100644 --- a/rlwe/src/parameters.rs +++ b/rlwe/src/parameters.rs @@ -59,11 +59,11 @@ impl Parameters { self.log_q + self.log_p } - pub fn limbs_q(&self) -> usize { + pub fn cols_q(&self) -> usize { (self.log_q + self.log_base2k - 1) / self.log_base2k } - pub fn limbs_qp(&self) -> usize { + pub fn cols_qp(&self) -> usize { (self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 36f9932..78a62cb 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,5 +1,5 @@ use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemVecZnx, VecZnxCommon}; +use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; use crate::parameters::Parameters; use base2k::{Module, VecZnx}; @@ -46,43 +46,61 @@ where Self(Elem::::from_bytes(module, log_base2k, log_q, 1, bytes)) } - pub fn n(&self) -> usize { - self.0.n() - } - - pub fn log_q(&self) -> usize { - self.0.log_q - } - - pub fn rows(&self) -> usize { - self.0.rows() - } - - pub fn cols(&self) -> usize { - self.0.cols() - } - - pub fn at(&self, i: usize) -> &T { - self.0.at(i) - } - - pub fn at_mut(&mut self, i: usize) -> &mut T { - self.0.at_mut(i) - } - - pub fn log_base2k(&self) -> usize { - self.0.log_base2k() - } - - pub fn log_scale(&self) -> usize { - self.0.log_scale() - } - - pub fn zero(&mut self) { - self.0.zero() - } - pub fn as_ciphertext(&self) -> Ciphertext { unsafe { Ciphertext::(std::ptr::read(&self.0)) } } } + +impl ElemCommon for Plaintext +where + T: VecZnxCommon, + Elem: ElemVecZnx, +{ + fn n(&self) -> usize { + self.0.n() + } + + fn log_n(&self) -> usize { + self.elem().log_n() + } + + fn log_q(&self) -> usize { + self.0.log_q + } + + fn elem(&self) -> &Elem { + &self.0 + } + + fn elem_mut(&mut self) -> &mut Elem { + &mut self.0 + } + + fn size(&self) -> usize { + self.elem().size() + } + + fn rows(&self) -> usize { + self.0.rows() + } + + fn cols(&self) -> usize { + self.0.cols() + } + + fn at(&self, i: usize) -> &T { + self.0.at(i) + } + + fn at_mut(&mut self, i: usize) -> &mut T { + self.0.at_mut(i) + } + + fn log_base2k(&self) -> usize { + self.0.log_base2k() + } + + fn log_scale(&self) -> usize { + self.0.log_scale() + } +} diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs new file mode 100644 index 0000000..f2bda0f --- /dev/null +++ b/rlwe/src/rgsw_product.rs @@ -0,0 +1,55 @@ +use crate::{ + ciphertext::Ciphertext, + elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}, +}; +use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; +use std::cmp::min; + +pub fn rgsw_product( + module: &Module, + _res: &mut Elem, + a: &Ciphertext, + b: &Ciphertext, + tmp_bytes: &mut [u8], +) where + T: VecZnxCommon, + Elem: ElemVecZnx, +{ + let _log_base2k: usize = b.log_base2k(); + let rows: usize = min(b.rows(), a.cols()); + let cols: usize = b.cols(); + let in_cols = a.cols(); + let out_cols: usize = a.cols(); + + let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols); + let bytes_of_vmp_apply_dft_to_dft = + module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols); + + let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) = + tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft); + + let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft); + let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); + let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft); + let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft); + let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft); + + // c0_dft <- DFT(a[0]) + module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols); + + // r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols] + module.vmp_apply_dft_to_dft( + &mut r1_dft, + &c1_dft, + &b.0.value[0], + bytes_of_vmp_apply_dft_to_dft, + ); + + // c1_dft <- DFT(a[1]) + module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols); +}