mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Fixed gadget product & added noise estimations
This commit is contained in:
30
Cargo.lock
generated
30
Cargo.lock
generated
@@ -49,6 +49,12 @@ version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||
|
||||
[[package]]
|
||||
name = "az"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973"
|
||||
|
||||
[[package]]
|
||||
name = "base2k"
|
||||
version = "0.1.0"
|
||||
@@ -58,6 +64,7 @@ dependencies = [
|
||||
"rand",
|
||||
"rand_core",
|
||||
"rand_distr",
|
||||
"rug",
|
||||
"sampling",
|
||||
"utils",
|
||||
]
|
||||
@@ -228,6 +235,16 @@ dependencies = [
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gmp-mpfr-sys"
|
||||
version = "1.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0205cd82059bc63b63cf516d714352a30c44f2c74da9961dfda2617ae6b5918"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
@@ -639,6 +656,7 @@ dependencies = [
|
||||
"base2k",
|
||||
"criterion",
|
||||
"rand_distr",
|
||||
"rug",
|
||||
"sampling",
|
||||
]
|
||||
|
||||
@@ -660,6 +678,18 @@ dependencies = [
|
||||
"utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rug"
|
||||
version = "1.27.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3"
|
||||
dependencies = [
|
||||
"az",
|
||||
"gmp-mpfr-sys",
|
||||
"libc",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.18"
|
||||
|
||||
@@ -3,6 +3,7 @@ members = ["base2k", "rlwe", "rns", "sampling", "utils"]
|
||||
|
||||
|
||||
[workspace.dependencies]
|
||||
rug = "1.27"
|
||||
rand = "0.8.4"
|
||||
rand_chacha = "0.3.1"
|
||||
rand_core = "0.6.4"
|
||||
|
||||
@@ -4,6 +4,7 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
rand = {workspace = true}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::{Infos, VecZnx, VecZnxApi};
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait Encoding {
|
||||
@@ -23,6 +24,13 @@ pub trait Encoding {
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]);
|
||||
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `log_base2k`: base two logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]);
|
||||
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -123,6 +131,36 @@ impl Encoding for VecZnx {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) {
|
||||
let cols: usize = self.cols();
|
||||
assert!(
|
||||
data.len() >= self.n(),
|
||||
"invalid data: data.len()={} < self.n()={}",
|
||||
data.len(),
|
||||
self.n()
|
||||
);
|
||||
|
||||
let prec: u32 = (log_base2k * cols) as u32;
|
||||
|
||||
// 2^{log_base2k}
|
||||
let base = Float::with_val(prec, (1 << log_base2k) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
|
||||
(0..cols).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
|
||||
@@ -12,6 +12,7 @@ pub mod free;
|
||||
pub mod infos;
|
||||
pub mod module;
|
||||
pub mod sampling;
|
||||
pub mod stats;
|
||||
pub mod svp;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
@@ -23,6 +24,8 @@ pub use free::*;
|
||||
pub use infos::*;
|
||||
pub use module::*;
|
||||
pub use sampling::*;
|
||||
#[allow(unused_imports)]
|
||||
pub use stats::*;
|
||||
pub use svp::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
|
||||
28
base2k/src/stats.rs
Normal file
28
base2k/src/stats.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use crate::{Infos, VecZnx, Encoding};
|
||||
use rug::float::Round;
|
||||
use rug::ops::{AddAssignRound, SubAssignRound, DivAssignRound};
|
||||
use rug::Float;
|
||||
|
||||
impl VecZnx {
|
||||
pub fn std(&self, log_base2k: usize) -> f64 {
|
||||
let prec: u32 = (self.cols() * log_base2k) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(log_base2k, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
avg.add_assign_round(x, Round::Nearest);
|
||||
});
|
||||
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
data.iter_mut().for_each(|x| {
|
||||
x.sub_assign_round(&avg, Round::Nearest);
|
||||
});
|
||||
let mut std: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
std += x*x
|
||||
});
|
||||
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
std = std.sqrt();
|
||||
std.to_f64()
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,9 @@ pub trait VecZnxApi {
|
||||
|
||||
/// Zeroes the backing array.
|
||||
fn zero(&mut self);
|
||||
|
||||
/// Normalization: propagates carry and ensures each coefficients
|
||||
/// falls into the range [-2^{K-1}, 2^{K-1}].
|
||||
fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]);
|
||||
|
||||
/// Right shifts the coefficients by k bits.
|
||||
|
||||
@@ -454,18 +454,18 @@ impl VmpPMatOps for Module {
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
res_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
gct_rows: usize,
|
||||
gct_cols: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.0,
|
||||
c_cols as u64,
|
||||
res_cols as u64,
|
||||
a_cols as u64,
|
||||
rows as u64,
|
||||
cols as u64,
|
||||
gct_rows as u64,
|
||||
gct_cols as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
@@ -495,18 +495,18 @@ impl VmpPMatOps for Module {
|
||||
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
res_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
gct_rows: usize,
|
||||
gct_cols: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.0,
|
||||
c_cols as u64,
|
||||
res_cols as u64,
|
||||
a_cols as u64,
|
||||
rows as u64,
|
||||
cols as u64,
|
||||
gct_rows as u64,
|
||||
gct_cols as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
base2k = {path="../base2k"}
|
||||
sampling = {path="../sampling"}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
/*
|
||||
use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8};
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use rlwe::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::Elem,
|
||||
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
||||
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
|
||||
elem::{Elem, ElemCommon},
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
|
||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
@@ -18,7 +19,9 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
gadget_ct: &'a Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &'a mut [u8],
|
||||
) -> Box<dyn FnMut() + 'a> {
|
||||
Box::new(move || gadget_product_inplace::<true, _>(module, elem, gadget_ct, tmp_bytes))
|
||||
Box::new(move || {
|
||||
gadget_product_inplace::<true, _>(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes)
|
||||
})
|
||||
}
|
||||
|
||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
|
||||
@@ -43,7 +46,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
| gen_switching_key_thread_safe_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.limbs_q(),
|
||||
params.cols_q(),
|
||||
params.log_q(),
|
||||
)
|
||||
| gadget_product_tmp_bytes(
|
||||
@@ -51,13 +54,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
params.log_base2k(),
|
||||
params.log_q(),
|
||||
params.log_q(),
|
||||
params.limbs_q(),
|
||||
params.cols_q(),
|
||||
params.log_qp(),
|
||||
)
|
||||
| encrypt_grlwe_sk_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.limbs_qp(),
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
),
|
||||
64,
|
||||
@@ -82,7 +85,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.limbs_q(),
|
||||
params.cols_q(),
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
@@ -123,3 +126,4 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
|
||||
criterion_group!(benches, bench_gadget_product_inplace);
|
||||
criterion_main!(benches);
|
||||
*/
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi};
|
||||
use rlwe::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::ElemCommon,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
@@ -22,7 +23,7 @@ fn main() {
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = vec![
|
||||
0u8;
|
||||
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
|
||||
params.decrypt_rlwe_tmp_byte(params.log_q())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
];
|
||||
|
||||
@@ -64,7 +65,7 @@ fn main() {
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
|
||||
params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
|
||||
|
||||
pt.0.value[0].print(pt.cols(), 16);
|
||||
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi, VmpPMat};
|
||||
use rlwe::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
decryptor::decrypt_rlwe_thread_safe,
|
||||
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
||||
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
|
||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
fn main() {
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
log_n: 4,
|
||||
log_q: 68,
|
||||
log_p: 17,
|
||||
log_base2k: 17,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 8,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new::<FFT64>(¶ms_lit);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = vec![
|
||||
0u8;
|
||||
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
| gen_switching_key_thread_safe_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.limbs_q(),
|
||||
params.log_q()
|
||||
)
|
||||
| gadget_product_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.log_q(),
|
||||
params.log_q(),
|
||||
params.limbs_q(),
|
||||
params.log_qp()
|
||||
)
|
||||
| encrypt_grlwe_sk_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.limbs_qp(),
|
||||
params.log_qp()
|
||||
)
|
||||
];
|
||||
|
||||
let mut source: Source = Source::new([3; 32]);
|
||||
|
||||
let mut sk0: SecretKey = SecretKey::new(params.module());
|
||||
let mut sk1: SecretKey = SecretKey::new(params.module());
|
||||
|
||||
sk0.fill_ternary_hw(params.xs(), &mut source);
|
||||
sk1.fill_ternary_hw(params.xs(), &mut source);
|
||||
|
||||
let mut want = vec![i64::default(); params.n()];
|
||||
|
||||
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
let log_base2k = params.log_base2k();
|
||||
|
||||
let log_k: usize = params.log_q() - 2 * log_base2k;
|
||||
|
||||
let mut source_xe: Source = Source::new([4; 32]);
|
||||
let mut source_xa: Source = Source::new([5; 32]);
|
||||
|
||||
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
|
||||
|
||||
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
|
||||
|
||||
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
|
||||
params.module(),
|
||||
log_base2k,
|
||||
params.limbs_q(),
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
encrypt_grlwe_sk_thread_safe(
|
||||
params.module(),
|
||||
&mut gadget_ct,
|
||||
&sk0.0,
|
||||
&sk1_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut pt: Plaintext<VecZnx> =
|
||||
Plaintext::<VecZnx>::new(params.module(), params.log_base2k(), params.log_q());
|
||||
|
||||
let mut want = vec![i64::default(); params.n()];
|
||||
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32);
|
||||
pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
|
||||
|
||||
params.encrypt_rlwe_sk_thread_safe(
|
||||
&mut ct,
|
||||
Some(&pt),
|
||||
&sk0_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
gadget_product_inplace::<true, _>(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes);
|
||||
|
||||
println!("ct.limbs()={}", ct.cols());
|
||||
println!("gadget_ct.rows()={}", gadget_ct.rows());
|
||||
println!("gadget_ct.cols()={}", gadget_ct.cols());
|
||||
println!("res.limbs()={}", ct.cols());
|
||||
println!();
|
||||
|
||||
decrypt_rlwe_thread_safe(
|
||||
params.module(),
|
||||
&mut pt.0,
|
||||
&ct.0,
|
||||
&sk1_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
pt.0.value[0].print(pt.cols(), 16);
|
||||
|
||||
let mut have: Vec<i64> = vec![i64::default(); params.n()];
|
||||
|
||||
println!("pt: {}", log_k);
|
||||
pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have);
|
||||
|
||||
println!("want: {:?}", &want[..16]);
|
||||
println!("have: {:?}", &have[..16]);
|
||||
}
|
||||
@@ -1,70 +1,77 @@
|
||||
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
|
||||
use crate::elem::{Elem, ElemCommon};
|
||||
use crate::parameters::Parameters;
|
||||
use crate::plaintext::Plaintext;
|
||||
use base2k::{Infos, Module, VecZnx, VmpPMat};
|
||||
|
||||
pub struct Ciphertext<T>(pub Elem<T>);
|
||||
|
||||
impl Parameters {
|
||||
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
|
||||
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ElemCommon<T> for Ciphertext<T>
|
||||
where
|
||||
T: Infos,
|
||||
{
|
||||
fn n(&self) -> usize {
|
||||
self.elem().n()
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
self.elem().log_n()
|
||||
}
|
||||
|
||||
fn log_q(&self) -> usize {
|
||||
self.elem().log_q()
|
||||
}
|
||||
|
||||
fn elem(&self) -> &Elem<T> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<T> {
|
||||
&mut self.0
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.elem().size()
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.elem().rows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.elem().cols()
|
||||
}
|
||||
|
||||
fn at(&self, i: usize) -> &T {
|
||||
self.elem().at(i)
|
||||
}
|
||||
|
||||
fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
self.elem_mut().at_mut(i)
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.elem().log_base2k()
|
||||
}
|
||||
|
||||
fn log_scale(&self) -> usize {
|
||||
self.elem().log_scale()
|
||||
}
|
||||
}
|
||||
|
||||
impl Ciphertext<VecZnx> {
|
||||
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
|
||||
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, rows))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Ciphertext<T>
|
||||
where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
{
|
||||
pub fn zero(&mut self) {
|
||||
self.0.zero()
|
||||
}
|
||||
|
||||
pub fn as_plaintext(&self) -> Plaintext<T> {
|
||||
unsafe { Plaintext::<T>(std::ptr::read(&self.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Ciphertext<T>
|
||||
where
|
||||
T: Infos,
|
||||
{
|
||||
pub fn n(&self) -> usize {
|
||||
self.0.n()
|
||||
}
|
||||
|
||||
pub fn log_q(&self) -> usize {
|
||||
self.0.log_q
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
self.0.rows()
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
self.0.cols()
|
||||
}
|
||||
|
||||
pub fn at(&self, i: usize) -> &T {
|
||||
self.0.at(i)
|
||||
}
|
||||
|
||||
pub fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
self.0.at_mut(i)
|
||||
}
|
||||
|
||||
pub fn log_base2k(&self) -> usize {
|
||||
self.0.log_base2k
|
||||
}
|
||||
|
||||
pub fn log_scale(&self) -> usize {
|
||||
self.0.log_scale
|
||||
}
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
|
||||
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
|
||||
}
|
||||
pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext<VecZnx> {
|
||||
let rows: usize = 2;
|
||||
Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
|
||||
}
|
||||
|
||||
pub fn new_gadget_ciphertext(
|
||||
@@ -74,7 +81,7 @@ pub fn new_gadget_ciphertext(
|
||||
log_q: usize,
|
||||
) -> Ciphertext<VmpPMat> {
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 1, rows, 2 * cols);
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
|
||||
elem.log_q = log_q;
|
||||
Ciphertext(elem)
|
||||
}
|
||||
@@ -86,7 +93,7 @@ pub fn new_rgsw_ciphertext(
|
||||
log_q: usize,
|
||||
) -> Ciphertext<VmpPMat> {
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, 2 * cols);
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
|
||||
elem.log_q = log_q;
|
||||
Ciphertext(elem)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::{Elem, ElemVecZnx, VecZnxCommon},
|
||||
elem::{Elem, ElemCommon, VecZnxCommon},
|
||||
keys::SecretKey,
|
||||
parameters::Parameters,
|
||||
plaintext::Plaintext,
|
||||
@@ -20,19 +20,19 @@ impl Decryptor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe_thread_safe_tmp_byte(module: &Module, limbs: usize) -> usize {
|
||||
pub fn decrypt_rlwe_tmp_byte(module: &Module, limbs: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn decrypt_rlwe_thread_safe_tmp_byte(&self, log_q: usize) -> usize {
|
||||
decrypt_rlwe_thread_safe_tmp_byte(
|
||||
pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize {
|
||||
decrypt_rlwe_tmp_byte(
|
||||
self.module(),
|
||||
(log_q + self.log_base2k() - 1) / self.log_base2k(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe_thread_safe<T>(
|
||||
pub fn decrypt_rlwe<T>(
|
||||
&self,
|
||||
res: &mut Plaintext<T>,
|
||||
ct: &Ciphertext<T>,
|
||||
@@ -40,13 +40,13 @@ impl Parameters {
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
|
||||
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe_thread_safe<T>(
|
||||
pub fn decrypt_rlwe<T>(
|
||||
module: &Module,
|
||||
res: &mut Elem<T>,
|
||||
a: &Elem<T>,
|
||||
@@ -54,15 +54,15 @@ pub fn decrypt_rlwe_thread_safe<T>(
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
let cols: usize = a.cols();
|
||||
|
||||
assert!(
|
||||
tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, cols),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}",
|
||||
tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}",
|
||||
tmp_bytes.len(),
|
||||
decrypt_rlwe_thread_safe_tmp_byte(module, cols)
|
||||
decrypt_rlwe_tmp_byte(module, cols)
|
||||
);
|
||||
|
||||
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols);
|
||||
|
||||
@@ -75,45 +75,68 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Infos> Elem<T> {
|
||||
pub fn n(&self) -> usize {
|
||||
pub trait ElemCommon<T> {
|
||||
fn n(&self) -> usize;
|
||||
fn log_n(&self) -> usize;
|
||||
fn elem(&self) -> &Elem<T>;
|
||||
fn elem_mut(&mut self) -> &mut Elem<T>;
|
||||
fn size(&self) -> usize;
|
||||
fn rows(&self) -> usize;
|
||||
fn cols(&self) -> usize;
|
||||
fn log_base2k(&self) -> usize;
|
||||
fn log_q(&self) -> usize;
|
||||
fn log_scale(&self) -> usize;
|
||||
fn at(&self, i: usize) -> &T;
|
||||
fn at_mut(&mut self, i: usize) -> &mut T;
|
||||
}
|
||||
|
||||
impl<T: Infos> ElemCommon<T> for Elem<T> {
|
||||
fn n(&self) -> usize {
|
||||
self.value[0].n()
|
||||
}
|
||||
|
||||
pub fn log_n(&self) -> usize {
|
||||
fn log_n(&self) -> usize {
|
||||
self.value[0].log_n()
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
fn elem(&self) -> &Elem<T> {
|
||||
self
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<T> {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.value.len()
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
fn rows(&self) -> usize {
|
||||
self.value[0].rows()
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
fn cols(&self) -> usize {
|
||||
self.value[0].cols()
|
||||
}
|
||||
|
||||
pub fn log_base2k(&self) -> usize {
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
pub fn log_q(&self) -> usize {
|
||||
fn log_q(&self) -> usize {
|
||||
self.log_q
|
||||
}
|
||||
|
||||
pub fn log_scale(&self) -> usize {
|
||||
fn log_scale(&self) -> usize {
|
||||
self.log_scale
|
||||
}
|
||||
|
||||
pub fn at(&self, i: usize) -> &T {
|
||||
fn at(&self, i: usize) -> &T {
|
||||
assert!(i < self.size());
|
||||
&self.value[i]
|
||||
}
|
||||
|
||||
pub fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
assert!(i < self.size());
|
||||
&mut self.value[i]
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
|
||||
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon};
|
||||
use crate::keys::SecretKey;
|
||||
use crate::parameters::Parameters;
|
||||
use crate::plaintext::Plaintext;
|
||||
@@ -32,7 +32,7 @@ impl EncryptorSk {
|
||||
initialized,
|
||||
source_xa: Source::new(new_seed()),
|
||||
source_xe: Source::new(new_seed()),
|
||||
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.limbs_qp())],
|
||||
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ impl EncryptorSk {
|
||||
pt: Option<&Plaintext<T>>,
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
assert!(
|
||||
self.initialized == true,
|
||||
@@ -82,7 +82,7 @@ impl EncryptorSk {
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
assert!(
|
||||
self.initialized == true,
|
||||
@@ -107,7 +107,7 @@ impl Parameters {
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
encrypt_rlwe_sk_thread_safe(
|
||||
self.module(),
|
||||
@@ -138,7 +138,7 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
Elem<T>: ElemCommon<T>,
|
||||
{
|
||||
let cols: usize = ct.cols();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
@@ -197,6 +197,12 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
|
||||
);
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||
encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_grlwe_sk_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
@@ -207,10 +213,10 @@ pub fn encrypt_grlwe_sk_tmp_bytes(
|
||||
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
|
||||
+ Plaintext::<VecZnx>::bytes_of(module, log_base2k, log_q)
|
||||
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||
+ module.vmp_prepare_tmp_bytes(rows, 2 * cols)
|
||||
+ module.vmp_prepare_tmp_bytes(rows, cols)
|
||||
}
|
||||
|
||||
pub fn encrypt_grlwe_sk_thread_safe(
|
||||
pub fn encrypt_grlwe_sk(
|
||||
module: &Module,
|
||||
ct: &mut Ciphertext<VmpPMat>,
|
||||
m: &Scalar,
|
||||
@@ -249,7 +255,7 @@ pub fn encrypt_grlwe_sk_thread_safe(
|
||||
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
|
||||
tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0);
|
||||
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.0);
|
||||
|
||||
// Encrypts RLWE(m * 2^{-log_base2k*i})
|
||||
encrypt_rlwe_sk_thread_safe(
|
||||
@@ -263,19 +269,28 @@ pub fn encrypt_grlwe_sk_thread_safe(
|
||||
tmp_bytes_enc_sk,
|
||||
);
|
||||
|
||||
//tmp_pt.at(0).print(tmp_pt.cols(), 16);
|
||||
//println!();
|
||||
|
||||
// Zeroes the ith-row of tmp_pt
|
||||
tmp_pt.0.value[0].at_mut(row_i).fill(0);
|
||||
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
|
||||
|
||||
//println!("row:{}/{}", row_i, rows);
|
||||
//tmp_elem.at(0).print(tmp_elem.limbs(), tmp_elem.n());
|
||||
//tmp_elem.at(1).print(tmp_elem.limbs(), tmp_elem.n());
|
||||
//tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n());
|
||||
//tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n());
|
||||
//println!();
|
||||
//println!(">>>");
|
||||
|
||||
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
|
||||
module.vmp_prepare_row(
|
||||
&mut ct.0.value[0],
|
||||
cast_mut::<u8, i64>(tmp_bytes_elem),
|
||||
&mut ct.at_mut(0),
|
||||
tmp_elem.at(0).raw(),
|
||||
row_i,
|
||||
tmp_bytes_vmp_prepare_row,
|
||||
);
|
||||
module.vmp_prepare_row(
|
||||
&mut ct.at_mut(1),
|
||||
tmp_elem.at(1).raw(),
|
||||
row_i,
|
||||
tmp_bytes_vmp_prepare_row,
|
||||
);
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::{Elem, ElemVecZnx, VecZnxCommon},
|
||||
};
|
||||
use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn gadget_product_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
out_log_q: usize,
|
||||
in_log_q: usize,
|
||||
gct_rows: usize,
|
||||
gct_log_q: usize,
|
||||
) -> usize {
|
||||
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
|
||||
let in_limbs: usize = (in_log_q + log_base2k - 1) / log_base2k;
|
||||
let out_limbs: usize = (out_log_q + log_base2k - 1) / log_base2k;
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, gct_rows, gct_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
|
||||
}
|
||||
|
||||
pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
|
||||
module: &Module,
|
||||
res: &mut Elem<T>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
unsafe {
|
||||
let a_ptr: *const T = res.at(1) as *const T;
|
||||
gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product res <- a x b.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols limbs.
|
||||
/// * `a`: a [VecZnx] of a_ncols limbs.
|
||||
/// * `b`: a [GadgetCiphertext] as a vector of (-Bs + m * 2^{-k} + E, B)
|
||||
/// containing b_nrows [VecZnx], each of b_ncols limbs.
|
||||
///
|
||||
/// # Computation
|
||||
///
|
||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
|
||||
pub fn gadget_product<const OVERWRITE: bool, T>(
|
||||
module: &Module,
|
||||
res: &mut Elem<T>,
|
||||
a: &T,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
let log_base2k: usize = b.log_base2k();
|
||||
let rows: usize = min(b.rows(), a.cols());
|
||||
let cols: usize = b.cols();
|
||||
|
||||
let bytes_vmp_apply_dft: usize =
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(cols, a.cols(), rows, cols);
|
||||
let bytes_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols);
|
||||
|
||||
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
|
||||
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
|
||||
|
||||
let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// Alias c0 and c1 part of res_big
|
||||
let (tmp_bytes_res_dft_c0, tmp_bytes_res_dft_c1) =
|
||||
tmp_bytes_res_dft.split_at_mut(bytes_vec_znx_dft >> 1);
|
||||
let res_big_c0: VecZnxBig = module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c0);
|
||||
let mut res_big_c1: VecZnxBig =
|
||||
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
|
||||
|
||||
// tmp_a_dft <- DFT(a)
|
||||
// (n x cols) <- (n x limbs=rows) x (rows x cols)
|
||||
// res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
|
||||
gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
|
||||
|
||||
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
|
||||
|
||||
// res_big <- res[0] + res_big[a*G0]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, res.at(0));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_c0, tmp_bytes_c1_dft);
|
||||
|
||||
if OVERWRITE {
|
||||
// res[1] = normalize(res_big[a*G1])
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft);
|
||||
} else {
|
||||
// res[1] = normalize(res_big[a*G1] + res[1])
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big_c1, res.at(1));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gadget_product_core<T>(
|
||||
module: &Module,
|
||||
res_dft: &mut VecZnxDft,
|
||||
a: &T,
|
||||
b: &VmpPMat,
|
||||
tmp_bytes_vmp_apply_dft: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
// res_dft <- DFT(a)
|
||||
module.vec_znx_dft(res_dft, a, a.cols());
|
||||
|
||||
// (n x cols) <- (n x limbs=rows) x (rows x cols)
|
||||
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
|
||||
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
|
||||
}
|
||||
|
||||
pub fn rgsw_product<T>(
|
||||
module: &Module,
|
||||
res: &mut Elem<T>,
|
||||
a: &Ciphertext<T>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
let log_base2k: usize = b.log_base2k();
|
||||
let rows: usize = min(b.rows(), a.cols());
|
||||
let cols: usize = b.cols();
|
||||
let in_cols = a.cols();
|
||||
let out_cols: usize = a.cols();
|
||||
|
||||
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols);
|
||||
let bytes_of_vmp_apply_dft_to_dft =
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols);
|
||||
|
||||
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
|
||||
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
|
||||
|
||||
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft);
|
||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
||||
let mut tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
|
||||
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
|
||||
let mut r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
|
||||
|
||||
// c0_dft <- DFT(a[0])
|
||||
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
|
||||
|
||||
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
|
||||
module.vmp_apply_dft_to_dft(
|
||||
&mut r1_dft,
|
||||
&c1_dft,
|
||||
&b.0.value[0],
|
||||
bytes_of_vmp_apply_dft_to_dft,
|
||||
);
|
||||
|
||||
// c1_dft <- DFT(a[1])
|
||||
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
|
||||
}
|
||||
371
rlwe/src/gadget_product.rs
Normal file
371
rlwe/src/gadget_product.rs
Normal file
@@ -0,0 +1,371 @@
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon},
|
||||
parameters::Parameters,
|
||||
};
|
||||
use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn gadget_product_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
res_log_q: usize,
|
||||
in_log_q: usize,
|
||||
gct_rows: usize,
|
||||
gct_log_q: usize,
|
||||
) -> usize {
|
||||
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
|
||||
let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k;
|
||||
let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k;
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn gadget_product_tmp_bytes(
|
||||
&self,
|
||||
res_log_q: usize,
|
||||
in_log_q: usize,
|
||||
gct_rows: usize,
|
||||
gct_log_q: usize,
|
||||
) -> usize {
|
||||
gadget_product_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
res_log_q,
|
||||
in_log_q,
|
||||
gct_rows,
|
||||
gct_log_q,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product res <- a x b.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols cols.
|
||||
/// * `a`: a [VecZnx] of a_ncols cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] as a vector of (-Bs + m * 2^{-k} + E, B)
|
||||
/// containing b_nrows [VecZnx], each of b_ncols cols.
|
||||
///
|
||||
/// # Computation
|
||||
///
|
||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
|
||||
pub fn gadget_product_core<const OVERWRITE: bool, T>(
|
||||
module: &Module,
|
||||
res_dft_0: &mut VecZnxDft,
|
||||
res_dft_1: &mut VecZnxDft,
|
||||
a: &T,
|
||||
a_cols: usize,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
assert!(b_cols <= b.cols());
|
||||
module.vec_znx_dft(res_dft_1, a, a_cols);
|
||||
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
|
||||
}
|
||||
|
||||
/*
|
||||
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
||||
module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
||||
|
||||
// res_big <- res[0] + res_big[a*G0]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry);
|
||||
|
||||
if OVERWRITE {
|
||||
// res[1] = normalize(res_big[a*G1])
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
||||
} else {
|
||||
// res[1] = normalize(res_big[a*G1] + res[1])
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
||||
}
|
||||
*/
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::{Elem, ElemCommon, ElemVecZnx},
|
||||
encryptor::encrypt_grlwe_sk,
|
||||
gadget_product::gadget_product_core,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{
|
||||
FFT64, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_gadget_product_core() {
|
||||
let log_base2k: usize = 10;
|
||||
let q_cols: usize = 7;
|
||||
let p_cols: usize = 1;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
log_n: 12,
|
||||
log_q: q_cols * log_base2k,
|
||||
log_p: p_cols * log_base2k,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new::<FFT64>(¶ms_lit);
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
vec![
|
||||
0u8;
|
||||
params.decrypt_rlwe_tmp_byte(params.log_qp())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())
|
||||
| params.gadget_product_tmp_bytes(
|
||||
params.log_qp(),
|
||||
params.log_qp(),
|
||||
params.cols_qp(),
|
||||
params.log_qp()
|
||||
)
|
||||
| params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp())
|
||||
];
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
// Two secret keys
|
||||
let mut sk0: SecretKey = SecretKey::new(params.module());
|
||||
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
|
||||
|
||||
let mut sk1: SecretKey = SecretKey::new(params.module());
|
||||
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
|
||||
|
||||
// The gadget ciphertext
|
||||
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
|
||||
params.module(),
|
||||
log_base2k,
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
// gct = [-b*sk1 + g(sk0) + e, b]
|
||||
encrypt_grlwe_sk(
|
||||
params.module(),
|
||||
&mut gadget_ct,
|
||||
&sk0.0,
|
||||
&sk1_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// Intermediate buffers
|
||||
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
|
||||
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
|
||||
|
||||
// Input polynopmial, uniformly distributed
|
||||
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
|
||||
params
|
||||
.module()
|
||||
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
|
||||
|
||||
// res = g^-1(a) * gct
|
||||
let mut elem_res: Elem<VecZnx> =
|
||||
Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
|
||||
|
||||
// Ideal output = a * s
|
||||
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
|
||||
let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
|
||||
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
|
||||
|
||||
// a * sk0
|
||||
params
|
||||
.module()
|
||||
.svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols());
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols());
|
||||
params.module().vec_znx_big_normalize(
|
||||
params.log_base2k(),
|
||||
&mut a_times_s,
|
||||
&a_big,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// Plaintext for decrypted output of gadget product
|
||||
let mut pt: Plaintext<VecZnx> =
|
||||
Plaintext::<VecZnx>::new(params.module(), params.log_base2k(), params.log_qp());
|
||||
|
||||
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
|
||||
|
||||
pt.elem_mut().zero();
|
||||
elem_res.zero();
|
||||
|
||||
let a_cols: usize = a.cols() - 1;
|
||||
let b_cols: usize = gadget_ct.cols();
|
||||
|
||||
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
|
||||
|
||||
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
|
||||
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
|
||||
gadget_product_core::<true, _>(
|
||||
params.module(),
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
&a,
|
||||
a_cols,
|
||||
&gadget_ct,
|
||||
b_cols,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// res_big_0 = IDFT(res_dft_0)
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
||||
// res_big_1 = IDFT(res_dft_1);
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
||||
|
||||
// res_big_0 = normalize(res_big_0)
|
||||
params.module().vec_znx_big_normalize(
|
||||
log_base2k,
|
||||
elem_res.at_mut(0),
|
||||
&res_big_0,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// res_big_1 = normalize(res_big_1)
|
||||
params.module().vec_znx_big_normalize(
|
||||
log_base2k,
|
||||
elem_res.at_mut(1),
|
||||
&res_big_1,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
|
||||
decrypt_rlwe(
|
||||
params.module(),
|
||||
pt.elem_mut(),
|
||||
&elem_res,
|
||||
&sk1_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// a * sk0 + e - a*sk0 = e
|
||||
params
|
||||
.module()
|
||||
.vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s);
|
||||
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
|
||||
|
||||
pt.at(0).print(pt.elem().cols(), 16);
|
||||
|
||||
println!("noise_have: {}", pt.at(0).std(log_base2k).log2());
|
||||
|
||||
let var_a_err: f64;
|
||||
|
||||
if a_cols < a.cols() {
|
||||
var_a_err = 1f64 / 12f64;
|
||||
} else {
|
||||
var_a_err = 0f64;
|
||||
}
|
||||
|
||||
let a_logq: usize = a_cols * log_base2k;
|
||||
let b_logq: usize = b_cols * log_base2k;
|
||||
let var_msg: f64 = params.xs() as f64;
|
||||
println!(
|
||||
"noise_pred: {}",
|
||||
params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn noise_grlwe_product(
|
||||
&self,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let n: f64 = self.n() as f64;
|
||||
let var_xs: f64 = self.xs() as f64;
|
||||
|
||||
let var_gct_err_lhs: f64;
|
||||
let var_gct_err_rhs: f64;
|
||||
if b_logq < self.log_qp() {
|
||||
let var_round: f64 = 1f64 / 12f64;
|
||||
var_gct_err_lhs = var_round;
|
||||
var_gct_err_rhs = var_round;
|
||||
} else {
|
||||
var_gct_err_lhs = self.xe() * self.xe();
|
||||
var_gct_err_rhs = 0f64;
|
||||
}
|
||||
|
||||
noise_grlwe_product(
|
||||
n,
|
||||
self.log_base2k(),
|
||||
var_xs,
|
||||
var_msg,
|
||||
var_a_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
a_logq,
|
||||
b_logq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noise_grlwe_product(
|
||||
n: f64,
|
||||
log_base2k: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||
let b_cols: usize = (b_logq + log_base2k - 1) / log_base2k;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (log_base2k)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
let var_round: f64 = 1f64 / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 =
|
||||
(a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a_err * a_scale * a_scale;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes};
|
||||
use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes};
|
||||
use crate::keys::{PublicKey, SecretKey, SwitchingKey};
|
||||
use crate::parameters::Parameters;
|
||||
use base2k::{Module, SvpPPol};
|
||||
@@ -40,7 +40,7 @@ impl KeyGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gen_switching_key_thread_safe_tmp_bytes(
|
||||
pub fn gen_switching_key_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
@@ -49,7 +49,7 @@ pub fn gen_switching_key_thread_safe_tmp_bytes(
|
||||
encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
pub fn gen_switching_key_thread_safe(
|
||||
pub fn gen_switching_key(
|
||||
module: &Module,
|
||||
swk: &mut SwitchingKey,
|
||||
sk_in: &SecretKey,
|
||||
@@ -59,7 +59,7 @@ pub fn gen_switching_key_thread_safe(
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
encrypt_grlwe_sk_thread_safe(
|
||||
encrypt_grlwe_sk(
|
||||
module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
|
||||
use crate::elem::Elem;
|
||||
use crate::elem::{Elem, ElemCommon};
|
||||
use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes};
|
||||
use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -2,8 +2,9 @@ pub mod ciphertext;
|
||||
pub mod decryptor;
|
||||
pub mod elem;
|
||||
pub mod encryptor;
|
||||
pub mod evaluator;
|
||||
pub mod gadget_product;
|
||||
pub mod key_generator;
|
||||
pub mod keys;
|
||||
pub mod parameters;
|
||||
pub mod plaintext;
|
||||
pub mod rgsw_product;
|
||||
|
||||
@@ -59,11 +59,11 @@ impl Parameters {
|
||||
self.log_q + self.log_p
|
||||
}
|
||||
|
||||
pub fn limbs_q(&self) -> usize {
|
||||
pub fn cols_q(&self) -> usize {
|
||||
(self.log_q + self.log_base2k - 1) / self.log_base2k
|
||||
}
|
||||
|
||||
pub fn limbs_qp(&self) -> usize {
|
||||
pub fn cols_qp(&self) -> usize {
|
||||
(self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::elem::{Elem, ElemVecZnx, VecZnxCommon};
|
||||
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon};
|
||||
use crate::parameters::Parameters;
|
||||
use base2k::{Module, VecZnx};
|
||||
|
||||
@@ -46,43 +46,61 @@ where
|
||||
Self(Elem::<T>::from_bytes(module, log_base2k, log_q, 1, bytes))
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.0.n()
|
||||
}
|
||||
|
||||
pub fn log_q(&self) -> usize {
|
||||
self.0.log_q
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
self.0.rows()
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
self.0.cols()
|
||||
}
|
||||
|
||||
pub fn at(&self, i: usize) -> &T {
|
||||
self.0.at(i)
|
||||
}
|
||||
|
||||
pub fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
self.0.at_mut(i)
|
||||
}
|
||||
|
||||
pub fn log_base2k(&self) -> usize {
|
||||
self.0.log_base2k()
|
||||
}
|
||||
|
||||
pub fn log_scale(&self) -> usize {
|
||||
self.0.log_scale()
|
||||
}
|
||||
|
||||
pub fn zero(&mut self) {
|
||||
self.0.zero()
|
||||
}
|
||||
|
||||
pub fn as_ciphertext(&self) -> Ciphertext<T> {
|
||||
unsafe { Ciphertext::<T>(std::ptr::read(&self.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ElemCommon<T> for Plaintext<T>
|
||||
where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
fn n(&self) -> usize {
|
||||
self.0.n()
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
self.elem().log_n()
|
||||
}
|
||||
|
||||
fn log_q(&self) -> usize {
|
||||
self.0.log_q
|
||||
}
|
||||
|
||||
fn elem(&self) -> &Elem<T> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<T> {
|
||||
&mut self.0
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.elem().size()
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.0.rows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.0.cols()
|
||||
}
|
||||
|
||||
fn at(&self, i: usize) -> &T {
|
||||
self.0.at(i)
|
||||
}
|
||||
|
||||
fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
self.0.at_mut(i)
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.0.log_base2k()
|
||||
}
|
||||
|
||||
fn log_scale(&self) -> usize {
|
||||
self.0.log_scale()
|
||||
}
|
||||
}
|
||||
|
||||
55
rlwe/src/rgsw_product.rs
Normal file
55
rlwe/src/rgsw_product.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon},
|
||||
};
|
||||
use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn rgsw_product<T>(
|
||||
module: &Module,
|
||||
_res: &mut Elem<T>,
|
||||
a: &Ciphertext<T>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) where
|
||||
T: VecZnxCommon<Owned = T>,
|
||||
Elem<T>: ElemVecZnx<T>,
|
||||
{
|
||||
let _log_base2k: usize = b.log_base2k();
|
||||
let rows: usize = min(b.rows(), a.cols());
|
||||
let cols: usize = b.cols();
|
||||
let in_cols = a.cols();
|
||||
let out_cols: usize = a.cols();
|
||||
|
||||
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols);
|
||||
let bytes_of_vmp_apply_dft_to_dft =
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols);
|
||||
|
||||
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
||||
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
|
||||
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
|
||||
|
||||
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft);
|
||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
||||
let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
|
||||
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
|
||||
let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
|
||||
|
||||
// c0_dft <- DFT(a[0])
|
||||
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
|
||||
|
||||
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
|
||||
module.vmp_apply_dft_to_dft(
|
||||
&mut r1_dft,
|
||||
&c1_dft,
|
||||
&b.0.value[0],
|
||||
bytes_of_vmp_apply_dft_to_dft,
|
||||
);
|
||||
|
||||
// c1_dft <- DFT(a[1])
|
||||
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
|
||||
}
|
||||
Reference in New Issue
Block a user