Fixed gadget product & added noise estimations

This commit is contained in:
Jean-Philippe Bossuat
2025-02-24 08:31:02 +01:00
parent 014bf0c2d1
commit 26c2bcbc05
24 changed files with 762 additions and 473 deletions

30
Cargo.lock generated
View File

@@ -49,6 +49,12 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "az"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973"
[[package]]
name = "base2k"
version = "0.1.0"
@@ -58,6 +64,7 @@ dependencies = [
"rand",
"rand_core",
"rand_distr",
"rug",
"sampling",
"utils",
]
@@ -228,6 +235,16 @@ dependencies = [
"wasi",
]
[[package]]
name = "gmp-mpfr-sys"
version = "1.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0205cd82059bc63b63cf516d714352a30c44f2c74da9961dfda2617ae6b5918"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "half"
version = "2.4.1"
@@ -639,6 +656,7 @@ dependencies = [
"base2k",
"criterion",
"rand_distr",
"rug",
"sampling",
]
@@ -660,6 +678,18 @@ dependencies = [
"utils",
]
[[package]]
name = "rug"
version = "1.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3"
dependencies = [
"az",
"gmp-mpfr-sys",
"libc",
"libm",
]
[[package]]
name = "ryu"
version = "1.0.18"

View File

@@ -3,6 +3,7 @@ members = ["base2k", "rlwe", "rns", "sampling", "utils"]
[workspace.dependencies]
rug = "1.27"
rand = "0.8.4"
rand_chacha = "0.3.1"
rand_core = "0.6.4"

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
rug = {workspace = true}
criterion = {workspace = true}
itertools = {workspace = true}
rand = {workspace = true}

View File

@@ -1,6 +1,7 @@
use crate::ffi::znx::znx_zero_i64_ref;
use crate::{Infos, VecZnx, VecZnxApi};
use itertools::izip;
use rug::{Assign, Float};
use std::cmp::min;
pub trait Encoding {
@@ -23,6 +24,13 @@ pub trait Encoding {
/// * `data`: data to decode from the receiver.
fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]);
/// decode a vector of Float from the receiver.
///
/// # Arguments
/// * `log_base2k`: base two logarithm decomposition of the receiver.
/// * `data`: data to decode from the receiver.
fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]);
/// encodes a single i64 on the receiver at the given index.
///
/// # Arguments
@@ -123,6 +131,36 @@ impl Encoding for VecZnx {
})
}
fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) {
let cols: usize = self.cols();
assert!(
data.len() >= self.n(),
"invalid data: data.len()={} < self.n()={}",
data.len(),
self.n()
);
let prec: u32 = (log_base2k * cols) as u32;
// 2^{log_base2k}
let base = Float::with_val(prec, (1 << log_base2k) as f64);
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
(0..cols).for_each(|i| {
if i == 0 {
izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
fn encode_coeff_i64(
&mut self,
log_base2k: usize,

View File

@@ -12,6 +12,7 @@ pub mod free;
pub mod infos;
pub mod module;
pub mod sampling;
pub mod stats;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
@@ -23,6 +24,8 @@ pub use free::*;
pub use infos::*;
pub use module::*;
pub use sampling::*;
#[allow(unused_imports)]
pub use stats::*;
pub use svp::*;
pub use vec_znx::*;
pub use vec_znx_big::*;

28
base2k/src/stats.rs Normal file
View File

@@ -0,0 +1,28 @@
use crate::{Infos, VecZnx, Encoding};
use rug::float::Round;
use rug::ops::{AddAssignRound, SubAssignRound, DivAssignRound};
use rug::Float;
impl VecZnx {
pub fn std(&self, log_base2k: usize) -> f64 {
let prec: u32 = (self.cols() * log_base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(log_base2k, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
std += x*x
});
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}

View File

@@ -42,6 +42,9 @@ pub trait VecZnxApi {
/// Zeroes the backing array.
fn zero(&mut self);
/// Normalization: propagates carry and ensures each coefficients
/// falls into the range [-2^{K-1}, 2^{K-1}].
fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]);
/// Right shifts the coefficients by k bits.

View File

@@ -454,18 +454,18 @@ impl VmpPMatOps for Module {
fn vmp_apply_dft_tmp_bytes(
&self,
c_cols: usize,
res_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
gct_rows: usize,
gct_cols: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.0,
c_cols as u64,
res_cols as u64,
a_cols as u64,
rows as u64,
cols as u64,
gct_rows as u64,
gct_cols as u64,
) as usize
}
}
@@ -495,18 +495,18 @@ impl VmpPMatOps for Module {
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
res_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
gct_rows: usize,
gct_cols: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.0,
c_cols as u64,
res_cols as u64,
a_cols as u64,
rows as u64,
cols as u64,
gct_rows as u64,
gct_cols as u64,
) as usize
}
}

View File

@@ -6,6 +6,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
rug = {workspace = true}
criterion = {workspace = true}
base2k = {path="../base2k"}
sampling = {path="../sampling"}

View File

@@ -1,10 +1,11 @@
/*
use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8};
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::Elem,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
elem::{Elem, ElemCommon},
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
@@ -18,7 +19,9 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
gadget_ct: &'a Ciphertext<VmpPMat>,
tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> {
Box::new(move || gadget_product_inplace::<true, _>(module, elem, gadget_ct, tmp_bytes))
Box::new(move || {
gadget_product_inplace::<true, _>(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes)
})
}
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
@@ -43,7 +46,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
| gen_switching_key_thread_safe_tmp_bytes(
params.module(),
params.log_base2k(),
params.limbs_q(),
params.cols_q(),
params.log_q(),
)
| gadget_product_tmp_bytes(
@@ -51,13 +54,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
params.log_base2k(),
params.log_q(),
params.log_q(),
params.limbs_q(),
params.cols_q(),
params.log_qp(),
)
| encrypt_grlwe_sk_tmp_bytes(
params.module(),
params.log_base2k(),
params.limbs_qp(),
params.cols_qp(),
params.log_qp(),
),
64,
@@ -82,7 +85,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
params.module(),
params.log_base2k(),
params.limbs_q(),
params.cols_q(),
params.log_qp(),
);
@@ -123,3 +126,4 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
criterion_group!(benches, bench_gadget_product_inplace);
criterion_main!(benches);
*/

View File

@@ -1,6 +1,7 @@
use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi};
use rlwe::{
ciphertext::Ciphertext,
elem::ElemCommon,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
@@ -22,7 +23,7 @@ fn main() {
let mut tmp_bytes: Vec<u8> = vec![
0u8;
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
params.decrypt_rlwe_tmp_byte(params.log_q())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
];
@@ -64,7 +65,7 @@ fn main() {
&mut tmp_bytes,
);
params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
pt.0.value[0].print(pt.cols(), 16);

View File

@@ -1,140 +0,0 @@
use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi, VmpPMat};
use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
decryptor::decrypt_rlwe_thread_safe,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use sampling::source::Source;
fn main() {
let params_lit: ParametersLiteral = ParametersLiteral {
log_n: 4,
log_q: 68,
log_p: 17,
log_base2k: 17,
log_scale: 20,
xe: 3.2,
xs: 8,
};
let params: Parameters = Parameters::new::<FFT64>(&params_lit);
let mut tmp_bytes: Vec<u8> = vec![
0u8;
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
| gen_switching_key_thread_safe_tmp_bytes(
params.module(),
params.log_base2k(),
params.limbs_q(),
params.log_q()
)
| gadget_product_tmp_bytes(
params.module(),
params.log_base2k(),
params.log_q(),
params.log_q(),
params.limbs_q(),
params.log_qp()
)
| encrypt_grlwe_sk_tmp_bytes(
params.module(),
params.log_base2k(),
params.limbs_qp(),
params.log_qp()
)
];
let mut source: Source = Source::new([3; 32]);
let mut sk0: SecretKey = SecretKey::new(params.module());
let mut sk1: SecretKey = SecretKey::new(params.module());
sk0.fill_ternary_hw(params.xs(), &mut source);
sk1.fill_ternary_hw(params.xs(), &mut source);
let mut want = vec![i64::default(); params.n()];
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let log_base2k = params.log_base2k();
let log_k: usize = params.log_q() - 2 * log_base2k;
let mut source_xe: Source = Source::new([4; 32]);
let mut source_xa: Source = Source::new([5; 32]);
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
params.module(),
log_base2k,
params.limbs_q(),
params.log_qp(),
);
encrypt_grlwe_sk_thread_safe(
params.module(),
&mut gadget_ct,
&sk0.0,
&sk1_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut pt: Plaintext<VecZnx> =
Plaintext::<VecZnx>::new(params.module(), params.log_base2k(), params.log_q());
let mut want = vec![i64::default(); params.n()];
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32);
pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
params.encrypt_rlwe_sk_thread_safe(
&mut ct,
Some(&pt),
&sk0_svp_ppol,
&mut source_xa,
&mut source_xe,
&mut tmp_bytes,
);
gadget_product_inplace::<true, _>(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes);
println!("ct.limbs()={}", ct.cols());
println!("gadget_ct.rows()={}", gadget_ct.rows());
println!("gadget_ct.cols()={}", gadget_ct.cols());
println!("res.limbs()={}", ct.cols());
println!();
decrypt_rlwe_thread_safe(
params.module(),
&mut pt.0,
&ct.0,
&sk1_svp_ppol,
&mut tmp_bytes,
);
pt.0.value[0].print(pt.cols(), 16);
let mut have: Vec<i64> = vec![i64::default(); params.n()];
println!("pt: {}", log_k);
pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have);
println!("want: {:?}", &want[..16]);
println!("have: {:?}", &have[..16]);
}

View File

@@ -1,70 +1,77 @@
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
use crate::elem::{Elem, ElemCommon};
use crate::parameters::Parameters;
use crate::plaintext::Plaintext;
use base2k::{Infos, Module, VecZnx, VmpPMat};
pub struct Ciphertext<T>(pub Elem<T>);
impl Parameters {
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
}
}
impl<T> ElemCommon<T> for Ciphertext<T>
where
T: Infos,
{
fn n(&self) -> usize {
self.elem().n()
}
fn log_n(&self) -> usize {
self.elem().log_n()
}
fn log_q(&self) -> usize {
self.elem().log_q()
}
fn elem(&self) -> &Elem<T> {
&self.0
}
fn elem_mut(&mut self) -> &mut Elem<T> {
&mut self.0
}
fn size(&self) -> usize {
self.elem().size()
}
fn rows(&self) -> usize {
self.elem().rows()
}
fn cols(&self) -> usize {
self.elem().cols()
}
fn at(&self, i: usize) -> &T {
self.elem().at(i)
}
fn at_mut(&mut self, i: usize) -> &mut T {
self.elem_mut().at_mut(i)
}
fn log_base2k(&self) -> usize {
self.elem().log_base2k()
}
fn log_scale(&self) -> usize {
self.elem().log_scale()
}
}
impl Ciphertext<VecZnx> {
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, rows))
}
}
impl<T> Ciphertext<T>
where
T: VecZnxCommon<Owned = T>,
{
pub fn zero(&mut self) {
self.0.zero()
}
pub fn as_plaintext(&self) -> Plaintext<T> {
unsafe { Plaintext::<T>(std::ptr::read(&self.0)) }
}
}
impl<T> Ciphertext<T>
where
T: Infos,
{
pub fn n(&self) -> usize {
self.0.n()
}
pub fn log_q(&self) -> usize {
self.0.log_q
}
pub fn rows(&self) -> usize {
self.0.rows()
}
pub fn cols(&self) -> usize {
self.0.cols()
}
pub fn at(&self, i: usize) -> &T {
self.0.at(i)
}
pub fn at_mut(&mut self, i: usize) -> &mut T {
self.0.at_mut(i)
}
pub fn log_base2k(&self) -> usize {
self.0.log_base2k
}
pub fn log_scale(&self) -> usize {
self.0.log_scale
}
}
impl Parameters {
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
}
pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext<VecZnx> {
let rows: usize = 2;
Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
}
pub fn new_gadget_ciphertext(
@@ -74,7 +81,7 @@ pub fn new_gadget_ciphertext(
log_q: usize,
) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 1, rows, 2 * cols);
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
elem.log_q = log_q;
Ciphertext(elem)
}
@@ -86,7 +93,7 @@ pub fn new_rgsw_ciphertext(
log_q: usize,
) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, 2 * cols);
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
elem.log_q = log_q;
Ciphertext(elem)
}

View File

@@ -1,6 +1,6 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemVecZnx, VecZnxCommon},
elem::{Elem, ElemCommon, VecZnxCommon},
keys::SecretKey,
parameters::Parameters,
plaintext::Plaintext,
@@ -20,19 +20,19 @@ impl Decryptor {
}
}
pub fn decrypt_rlwe_thread_safe_tmp_byte(module: &Module, limbs: usize) -> usize {
pub fn decrypt_rlwe_tmp_byte(module: &Module, limbs: usize) -> usize {
module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes()
}
impl Parameters {
pub fn decrypt_rlwe_thread_safe_tmp_byte(&self, log_q: usize) -> usize {
decrypt_rlwe_thread_safe_tmp_byte(
pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize {
decrypt_rlwe_tmp_byte(
self.module(),
(log_q + self.log_base2k() - 1) / self.log_base2k(),
)
}
pub fn decrypt_rlwe_thread_safe<T>(
pub fn decrypt_rlwe<T>(
&self,
res: &mut Plaintext<T>,
ct: &Ciphertext<T>,
@@ -40,13 +40,13 @@ impl Parameters {
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
}
}
pub fn decrypt_rlwe_thread_safe<T>(
pub fn decrypt_rlwe<T>(
module: &Module,
res: &mut Elem<T>,
a: &Elem<T>,
@@ -54,15 +54,15 @@ pub fn decrypt_rlwe_thread_safe<T>(
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
let cols: usize = a.cols();
assert!(
tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, cols),
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}",
tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols),
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}",
tmp_bytes.len(),
decrypt_rlwe_thread_safe_tmp_byte(module, cols)
decrypt_rlwe_tmp_byte(module, cols)
);
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols);

View File

@@ -75,45 +75,68 @@ where
}
}
impl<T: Infos> Elem<T> {
pub fn n(&self) -> usize {
pub trait ElemCommon<T> {
fn n(&self) -> usize;
fn log_n(&self) -> usize;
fn elem(&self) -> &Elem<T>;
fn elem_mut(&mut self) -> &mut Elem<T>;
fn size(&self) -> usize;
fn rows(&self) -> usize;
fn cols(&self) -> usize;
fn log_base2k(&self) -> usize;
fn log_q(&self) -> usize;
fn log_scale(&self) -> usize;
fn at(&self, i: usize) -> &T;
fn at_mut(&mut self, i: usize) -> &mut T;
}
impl<T: Infos> ElemCommon<T> for Elem<T> {
fn n(&self) -> usize {
self.value[0].n()
}
pub fn log_n(&self) -> usize {
fn log_n(&self) -> usize {
self.value[0].log_n()
}
pub fn size(&self) -> usize {
fn elem(&self) -> &Elem<T> {
self
}
fn elem_mut(&mut self) -> &mut Elem<T> {
self
}
fn size(&self) -> usize {
self.value.len()
}
pub fn rows(&self) -> usize {
fn rows(&self) -> usize {
self.value[0].rows()
}
pub fn cols(&self) -> usize {
fn cols(&self) -> usize {
self.value[0].cols()
}
pub fn log_base2k(&self) -> usize {
fn log_base2k(&self) -> usize {
self.log_base2k
}
pub fn log_q(&self) -> usize {
fn log_q(&self) -> usize {
self.log_q
}
pub fn log_scale(&self) -> usize {
fn log_scale(&self) -> usize {
self.log_scale
}
pub fn at(&self, i: usize) -> &T {
fn at(&self, i: usize) -> &T {
assert!(i < self.size());
&self.value[i]
}
pub fn at_mut(&mut self, i: usize) -> &mut T {
fn at_mut(&mut self, i: usize) -> &mut T {
assert!(i < self.size());
&mut self.value[i]
}

View File

@@ -1,5 +1,5 @@
use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon};
use crate::keys::SecretKey;
use crate::parameters::Parameters;
use crate::plaintext::Plaintext;
@@ -32,7 +32,7 @@ impl EncryptorSk {
initialized,
source_xa: Source::new(new_seed()),
source_xe: Source::new(new_seed()),
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.limbs_qp())],
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())],
}
}
@@ -56,7 +56,7 @@ impl EncryptorSk {
pt: Option<&Plaintext<T>>,
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
assert!(
self.initialized == true,
@@ -82,7 +82,7 @@ impl EncryptorSk {
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
assert!(
self.initialized == true,
@@ -107,7 +107,7 @@ impl Parameters {
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
encrypt_rlwe_sk_thread_safe(
self.module(),
@@ -138,7 +138,7 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
Elem<T>: ElemCommon<T>,
{
let cols: usize = ct.cols();
let log_base2k: usize = ct.log_base2k();
@@ -197,6 +197,12 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
);
}
impl Parameters {
pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
}
pub fn encrypt_grlwe_sk_tmp_bytes(
module: &Module,
log_base2k: usize,
@@ -207,10 +213,10 @@ pub fn encrypt_grlwe_sk_tmp_bytes(
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::<VecZnx>::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, 2 * cols)
+ module.vmp_prepare_tmp_bytes(rows, cols)
}
pub fn encrypt_grlwe_sk_thread_safe(
pub fn encrypt_grlwe_sk(
module: &Module,
ct: &mut Ciphertext<VmpPMat>,
m: &Scalar,
@@ -249,7 +255,7 @@ pub fn encrypt_grlwe_sk_thread_safe(
(0..rows).for_each(|row_i| {
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0);
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.0);
// Encrypts RLWE(m * 2^{-log_base2k*i})
encrypt_rlwe_sk_thread_safe(
@@ -263,19 +269,28 @@ pub fn encrypt_grlwe_sk_thread_safe(
tmp_bytes_enc_sk,
);
//tmp_pt.at(0).print(tmp_pt.cols(), 16);
//println!();
// Zeroes the ith-row of tmp_pt
tmp_pt.0.value[0].at_mut(row_i).fill(0);
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
//println!("row:{}/{}", row_i, rows);
//tmp_elem.at(0).print(tmp_elem.limbs(), tmp_elem.n());
//tmp_elem.at(1).print(tmp_elem.limbs(), tmp_elem.n());
//tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n());
//tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n());
//println!();
//println!(">>>");
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
module.vmp_prepare_row(
&mut ct.0.value[0],
cast_mut::<u8, i64>(tmp_bytes_elem),
&mut ct.at_mut(0),
tmp_elem.at(0).raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);
module.vmp_prepare_row(
&mut ct.at_mut(1),
tmp_elem.at(1).raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);

View File

@@ -1,171 +0,0 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemVecZnx, VecZnxCommon},
};
use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min;
pub fn gadget_product_tmp_bytes(
module: &Module,
log_base2k: usize,
out_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
let in_limbs: usize = (in_log_q + log_base2k - 1) / log_base2k;
let out_limbs: usize = (out_log_q + log_base2k - 1) / log_base2k;
module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, gct_rows, gct_cols)
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
}
pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
unsafe {
let a_ptr: *const T = res.at(1) as *const T;
gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
}
}
/// Evaluates the gadget product res <- a x b.
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols limbs.
/// * `a`: a [VecZnx] of a_ncols limbs.
/// * `b`: a [GadgetCiphertext] as a vector of (-Bs + m * 2^{-k} + E, B)
/// containing b_nrows [VecZnx], each of b_ncols limbs.
///
/// # Computation
///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
pub fn gadget_product<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
a: &T,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
let log_base2k: usize = b.log_base2k();
let rows: usize = min(b.rows(), a.cols());
let cols: usize = b.cols();
let bytes_vmp_apply_dft: usize =
module.vmp_apply_dft_to_dft_tmp_bytes(cols, a.cols(), rows, cols);
let bytes_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols);
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// Alias c0 and c1 part of res_big
let (tmp_bytes_res_dft_c0, tmp_bytes_res_dft_c1) =
tmp_bytes_res_dft.split_at_mut(bytes_vec_znx_dft >> 1);
let res_big_c0: VecZnxBig = module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c0);
let mut res_big_c1: VecZnxBig =
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
// tmp_a_dft <- DFT(a)
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
// res_big <- res[0] + res_big[a*G0]
module.vec_znx_big_add_small_inplace(&mut res_big, res.at(0));
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_c0, tmp_bytes_c1_dft);
if OVERWRITE {
// res[1] = normalize(res_big[a*G1])
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft);
} else {
// res[1] = normalize(res_big[a*G1] + res[1])
module.vec_znx_big_add_small_inplace(&mut res_big_c1, res.at(1));
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft);
}
}
pub fn gadget_product_core<T>(
module: &Module,
res_dft: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
tmp_bytes_vmp_apply_dft: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
// res_dft <- DFT(a)
module.vec_znx_dft(res_dft, a, a.cols());
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
}
pub fn rgsw_product<T>(
module: &Module,
res: &mut Elem<T>,
a: &Ciphertext<T>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
let log_base2k: usize = b.log_base2k();
let rows: usize = min(b.rows(), a.cols());
let cols: usize = b.cols();
let in_cols = a.cols();
let out_cols: usize = a.cols();
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols);
let bytes_of_vmp_apply_dft_to_dft =
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols);
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
let mut r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
// c0_dft <- DFT(a[0])
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
module.vmp_apply_dft_to_dft(
&mut r1_dft,
&c1_dft,
&b.0.value[0],
bytes_of_vmp_apply_dft_to_dft,
);
// c1_dft <- DFT(a[1])
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
}

371
rlwe/src/gadget_product.rs Normal file
View File

@@ -0,0 +1,371 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon},
parameters::Parameters,
};
use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min;
pub fn gadget_product_tmp_bytes(
module: &Module,
log_base2k: usize,
res_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k;
let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k;
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols)
}
impl Parameters {
pub fn gadget_product_tmp_bytes(
&self,
res_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
gadget_product_tmp_bytes(
self.module(),
self.log_base2k(),
res_log_q,
in_log_q,
gct_rows,
gct_log_q,
)
}
}
/// Evaluates the gadget product res <- a x b.
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols cols.
/// * `a`: a [VecZnx] of a_ncols cols.
/// * `b`: a [Ciphertext<VmpPMat>] as a vector of (-Bs + m * 2^{-k} + E, B)
/// containing b_nrows [VecZnx], each of b_ncols cols.
///
/// # Computation
///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
pub fn gadget_product_core<const OVERWRITE: bool, T>(
module: &Module,
res_dft_0: &mut VecZnxDft,
res_dft_1: &mut VecZnxDft,
a: &T,
a_cols: usize,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
assert!(b_cols <= b.cols());
module.vec_znx_dft(res_dft_1, a, a_cols);
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
}
/*
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
// res_big <- res[0] + res_big[a*G0]
module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0));
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry);
if OVERWRITE {
// res[1] = normalize(res_big[a*G1])
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
} else {
// res[1] = normalize(res_big[a*G1] + res[1])
module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1));
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
}
*/
#[cfg(test)]
mod test {
use crate::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
decryptor::decrypt_rlwe,
elem::{Elem, ElemCommon, ElemVecZnx},
encryptor::encrypt_grlwe_sk,
gadget_product::gadget_product_core,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{
FFT64, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, VmpPMat,
};
use sampling::source::{Source, new_seed};
#[test]
fn test_gadget_product_core() {
let log_base2k: usize = 10;
let q_cols: usize = 7;
let p_cols: usize = 1;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
log_n: 12,
log_q: q_cols * log_base2k,
log_p: p_cols * log_base2k,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new::<FFT64>(&params_lit);
// scratch space
let mut tmp_bytes: Vec<u8> =
vec![
0u8;
params.decrypt_rlwe_tmp_byte(params.log_qp())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())
| params.gadget_product_tmp_bytes(
params.log_qp(),
params.log_qp(),
params.cols_qp(),
params.log_qp()
)
| params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp())
];
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
// Two secret keys
let mut sk0: SecretKey = SecretKey::new(params.module());
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
let mut sk1: SecretKey = SecretKey::new(params.module());
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
// The gadget ciphertext
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
params.module(),
log_base2k,
params.cols_qp(),
params.log_qp(),
);
// gct = [-b*sk1 + g(sk0) + e, b]
encrypt_grlwe_sk(
params.module(),
&mut gadget_ct,
&sk0.0,
&sk1_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
// Intermediate buffers
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
// Input polynopmial, uniformly distributed
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
params
.module()
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
// res = g^-1(a) * gct
let mut elem_res: Elem<VecZnx> =
Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
// Ideal output = a * s
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
// a * sk0
params
.module()
.svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols());
params
.module()
.vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols());
params.module().vec_znx_big_normalize(
params.log_base2k(),
&mut a_times_s,
&a_big,
&mut tmp_bytes,
);
// Plaintext for decrypted output of gadget product
let mut pt: Plaintext<VecZnx> =
Plaintext::<VecZnx>::new(params.module(), params.log_base2k(), params.log_qp());
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
pt.elem_mut().zero();
elem_res.zero();
let a_cols: usize = a.cols() - 1;
let b_cols: usize = gadget_ct.cols();
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
gadget_product_core::<true, _>(
params.module(),
&mut res_dft_0,
&mut res_dft_1,
&a,
a_cols,
&gadget_ct,
b_cols,
&mut tmp_bytes,
);
// res_big_0 = IDFT(res_dft_0)
params
.module()
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
// res_big_1 = IDFT(res_dft_1);
params
.module()
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
// res_big_0 = normalize(res_big_0)
params.module().vec_znx_big_normalize(
log_base2k,
elem_res.at_mut(0),
&res_big_0,
&mut tmp_bytes,
);
// res_big_1 = normalize(res_big_1)
params.module().vec_znx_big_normalize(
log_base2k,
elem_res.at_mut(1),
&res_big_1,
&mut tmp_bytes,
);
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
decrypt_rlwe(
params.module(),
pt.elem_mut(),
&elem_res,
&sk1_svp_ppol,
&mut tmp_bytes,
);
// a * sk0 + e - a*sk0 = e
params
.module()
.vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s);
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
pt.at(0).print(pt.elem().cols(), 16);
println!("noise_have: {}", pt.at(0).std(log_base2k).log2());
let var_a_err: f64;
if a_cols < a.cols() {
var_a_err = 1f64 / 12f64;
} else {
var_a_err = 0f64;
}
let a_logq: usize = a_cols * log_base2k;
let b_logq: usize = b_cols * log_base2k;
let var_msg: f64 = params.xs() as f64;
println!(
"noise_pred: {}",
params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq)
);
}
}
impl Parameters {
pub fn noise_grlwe_product(
&self,
var_msg: f64,
var_a_err: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let n: f64 = self.n() as f64;
let var_xs: f64 = self.xs() as f64;
let var_gct_err_lhs: f64;
let var_gct_err_rhs: f64;
if b_logq < self.log_qp() {
let var_round: f64 = 1f64 / 12f64;
var_gct_err_lhs = var_round;
var_gct_err_rhs = var_round;
} else {
var_gct_err_lhs = self.xe() * self.xe();
var_gct_err_rhs = 0f64;
}
noise_grlwe_product(
n,
self.log_base2k(),
var_xs,
var_msg,
var_a_err,
var_gct_err_lhs,
var_gct_err_rhs,
a_logq,
b_logq,
)
}
}
pub fn noise_grlwe_product(
n: f64,
log_base2k: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
let b_cols: usize = (b_logq + log_base2k - 1) / log_base2k;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (log_base2k)) as f64;
let var_base: f64 = base * base / 12f64;
let var_round: f64 = 1f64 / 12f64;
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
let mut noise: f64 =
(a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a_err * a_scale * a_scale;
noise = noise.sqrt();
noise /= b_scale;
noise.log2()
}

View File

@@ -1,4 +1,4 @@
use crate::encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes};
use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes};
use crate::keys::{PublicKey, SecretKey, SwitchingKey};
use crate::parameters::Parameters;
use base2k::{Module, SvpPPol};
@@ -40,7 +40,7 @@ impl KeyGenerator {
}
}
pub fn gen_switching_key_thread_safe_tmp_bytes(
pub fn gen_switching_key_tmp_bytes(
module: &Module,
log_base2k: usize,
rows: usize,
@@ -49,7 +49,7 @@ pub fn gen_switching_key_thread_safe_tmp_bytes(
encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
pub fn gen_switching_key_thread_safe(
pub fn gen_switching_key(
module: &Module,
swk: &mut SwitchingKey,
sk_in: &SecretKey,
@@ -59,7 +59,7 @@ pub fn gen_switching_key_thread_safe(
sigma: f64,
tmp_bytes: &mut [u8],
) {
encrypt_grlwe_sk_thread_safe(
encrypt_grlwe_sk(
module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
}

View File

@@ -1,5 +1,5 @@
use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
use crate::elem::Elem;
use crate::elem::{Elem, ElemCommon};
use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes};
use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
use sampling::source::Source;

View File

@@ -2,8 +2,9 @@ pub mod ciphertext;
pub mod decryptor;
pub mod elem;
pub mod encryptor;
pub mod evaluator;
pub mod gadget_product;
pub mod key_generator;
pub mod keys;
pub mod parameters;
pub mod plaintext;
pub mod rgsw_product;

View File

@@ -59,11 +59,11 @@ impl Parameters {
self.log_q + self.log_p
}
pub fn limbs_q(&self) -> usize {
pub fn cols_q(&self) -> usize {
(self.log_q + self.log_base2k - 1) / self.log_base2k
}
pub fn limbs_qp(&self) -> usize {
pub fn cols_qp(&self) -> usize {
(self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k
}

View File

@@ -1,5 +1,5 @@
use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon};
use crate::parameters::Parameters;
use base2k::{Module, VecZnx};
@@ -46,43 +46,61 @@ where
Self(Elem::<T>::from_bytes(module, log_base2k, log_q, 1, bytes))
}
pub fn n(&self) -> usize {
self.0.n()
}
pub fn log_q(&self) -> usize {
self.0.log_q
}
pub fn rows(&self) -> usize {
self.0.rows()
}
pub fn cols(&self) -> usize {
self.0.cols()
}
pub fn at(&self, i: usize) -> &T {
self.0.at(i)
}
pub fn at_mut(&mut self, i: usize) -> &mut T {
self.0.at_mut(i)
}
pub fn log_base2k(&self) -> usize {
self.0.log_base2k()
}
pub fn log_scale(&self) -> usize {
self.0.log_scale()
}
pub fn zero(&mut self) {
self.0.zero()
}
pub fn as_ciphertext(&self) -> Ciphertext<T> {
unsafe { Ciphertext::<T>(std::ptr::read(&self.0)) }
}
}
impl<T> ElemCommon<T> for Plaintext<T>
where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
fn n(&self) -> usize {
self.0.n()
}
fn log_n(&self) -> usize {
self.elem().log_n()
}
fn log_q(&self) -> usize {
self.0.log_q
}
fn elem(&self) -> &Elem<T> {
&self.0
}
fn elem_mut(&mut self) -> &mut Elem<T> {
&mut self.0
}
fn size(&self) -> usize {
self.elem().size()
}
fn rows(&self) -> usize {
self.0.rows()
}
fn cols(&self) -> usize {
self.0.cols()
}
fn at(&self, i: usize) -> &T {
self.0.at(i)
}
fn at_mut(&mut self, i: usize) -> &mut T {
self.0.at_mut(i)
}
fn log_base2k(&self) -> usize {
self.0.log_base2k()
}
fn log_scale(&self) -> usize {
self.0.log_scale()
}
}

55
rlwe/src/rgsw_product.rs Normal file
View File

@@ -0,0 +1,55 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon},
};
use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min;
pub fn rgsw_product<T>(
module: &Module,
_res: &mut Elem<T>,
a: &Ciphertext<T>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
let _log_base2k: usize = b.log_base2k();
let rows: usize = min(b.rows(), a.cols());
let cols: usize = b.cols();
let in_cols = a.cols();
let out_cols: usize = a.cols();
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols);
let bytes_of_vmp_apply_dft_to_dft =
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols);
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
// c0_dft <- DFT(a[0])
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
module.vmp_apply_dft_to_dft(
&mut r1_dft,
&c1_dft,
&b.0.value[0],
bytes_of_vmp_apply_dft_to_dft,
);
// c1_dft <- DFT(a[1])
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
}