Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Commit: Merge branch 'dev_trace'

.gitmodules (vendored): 2 changes
@@ -1,3 +1,3 @@
 [submodule "base2k/spqlios-arithmetic"]
     path = base2k/spqlios-arithmetic
-    url = https://github.com/Pro7ech/spqlios-arithmetic
+    url = https://github.com/phantomzone-org/spqlios-arithmetic
@@ -1,14 +1,9 @@
 use base2k::ffi::reim::*;
-use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
 use std::ffi::c_void;
 
 fn fft(c: &mut Criterion) {
-    fn forward<'a>(
-        m: u32,
-        log_bound: u32,
-        reim_fft_precomp: *mut reim_fft_precomp,
-        a: &'a [i64],
-    ) -> Box<dyn FnMut() + 'a> {
+    fn forward<'a>(m: u32, log_bound: u32, reim_fft_precomp: *mut reim_fft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
         unsafe {
             let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
             reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
@@ -16,12 +11,7 @@ fn fft(c: &mut Criterion) {
         }
     }
 
-    fn backward<'a>(
-        m: u32,
-        log_bound: u32,
-        reim_ifft_precomp: *mut reim_ifft_precomp,
-        a: &'a [i64],
-    ) -> Box<dyn FnMut() + 'a> {
+    fn backward<'a>(m: u32, log_bound: u32, reim_ifft_precomp: *mut reim_ifft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
         Box::new(move || unsafe {
             let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
             reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
@@ -29,8 +19,7 @@ fn fft(c: &mut Criterion) {
         })
     }
 
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
-        c.benchmark_group("fft");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fft");
 
     for log_n in 10..17 {
         let n: usize = 1 << log_n;
@@ -1,6 +1,6 @@
 use base2k::{
-    alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx,
-    VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND,
+    BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
+    VecZnxDftOps, VecZnxOps, alloc_aligned,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -38,13 +38,13 @@ fn main() {
     let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.cols());
 
     // Applies buf_dft <- s * a
-    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.cols());
+    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
 
     // Alias scratch space
     let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
 
     // buf_big <- IDFT(buf_dft) (not normalized)
-    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.cols());
+    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
 
     let mut m: VecZnx = module.new_vec_znx(msg_cols);
 
@@ -71,11 +71,11 @@ fn main() {
         19.0,
     );
 
-    //Decrypt
+    // Decrypt
 
     // buf_big <- a * s
-    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.cols());
-    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.cols());
+    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
+    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
 
     // buf_big <- a * s + b
     module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
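The example above already uses the updated `svp_apply_dft` and `vec_znx_idft_tmp_a` signatures, which drop the explicit limb-count argument and read it from the operands instead. A minimal caller-side sketch of the new calls (a hedged fragment, not part of the repository; names mirror the example):

```rust
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftOps};

// Sketch: buf_dft <- s * a, then buf_big <- IDFT(buf_dft), with sizes taken from the operands.
fn mul_by_secret(module: &Module, s_ppol: &SvpPPol, a: &VecZnx, buf_dft: &mut VecZnxDft, buf_big: &mut VecZnxBig) {
    // Previously: module.svp_apply_dft(buf_dft, s_ppol, a, a.cols());
    module.svp_apply_dft(buf_dft, s_ppol, a);
    // Previously: module.vec_znx_idft_tmp_a(buf_big, buf_dft, a.cols());
    module.vec_znx_idft_tmp_a(buf_big, buf_dft);
}
```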
@@ -1,6 +1,6 @@
 use base2k::{
-    alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
-    VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, BACKEND,
+    BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat,
+    VmpPMatOps, alloc_aligned,
 };
 
 fn main() {
@@ -16,8 +16,7 @@ fn main() {
     let cols: usize = cols + 1;
 
     // Maximum size of the byte scratch needed
-    let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols)
-        | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
+    let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
 
     let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
 
@@ -49,7 +48,7 @@ fn main() {
     module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
 
     let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
-    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, cols);
+    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
 
     let mut res: VecZnx = module.new_vec_znx(cols);
     module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);
Submodule base2k/spqlios-arithmetic updated: 07f3c8d2b8...e3d3247335
@@ -40,14 +40,7 @@ pub trait Encoding {
     /// * `i`: index of the coefficient on which to encode the data.
     /// * `data`: data to encode on the receiver.
     /// * `log_max`: base two logarithm of the infinity norm of the input data.
-    fn encode_coeff_i64(
-        &mut self,
-        log_base2k: usize,
-        log_k: usize,
-        i: usize,
-        data: i64,
-        log_max: usize,
-    );
+    fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
 
     /// decode a single of i64 from the receiver at the given index.
     ///
@@ -73,14 +66,7 @@ impl Encoding for VecZnx {
         decode_vec_float(self, log_base2k, data)
     }
 
-    fn encode_coeff_i64(
-        &mut self,
-        log_base2k: usize,
-        log_k: usize,
-        i: usize,
-        value: i64,
-        log_max: usize,
-    ) {
+    fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
         encode_coeff_i64(self, log_base2k, log_k, i, value, log_max)
     }
 
@@ -119,8 +105,7 @@ fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64],
         .enumerate()
         .for_each(|(i, i_rev)| {
             let shift: usize = i * log_base2k;
-            izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter())
-                .for_each(|(y, x)| *y = (x >> shift) & mask);
+            izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
         })
 }
 
@@ -189,14 +174,7 @@ fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) {
         });
 }
 
-fn encode_coeff_i64(
-    a: &mut VecZnx,
-    log_base2k: usize,
-    log_k: usize,
-    i: usize,
-    value: i64,
-    log_max: usize,
-) {
+fn encode_coeff_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
     debug_assert!(i < a.n());
     let cols: usize = (log_k + log_base2k - 1) / log_base2k;
     debug_assert!(
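The unchanged body of `encode_vec_i64` above splits each coefficient into base-2^k limbs with `(x >> shift) & mask`. A self-contained sketch of that decomposition, illustrative only (the crate stores the limbs across the columns of a `VecZnx`; here they are just collected least-significant first):

```rust
// Decompose a value into `cols` limbs of `log_base2k` bits each, mirroring the
// shift-and-mask used by encode_vec_i64.
fn limbs_base2k(x: i64, log_base2k: usize, cols: usize) -> Vec<i64> {
    let mask: i64 = (1i64 << log_base2k) - 1;
    (0..cols).map(|i| (x >> (i * log_base2k)) & mask).collect()
}

fn main() {
    // 0x0FAB split into 4-bit limbs, least significant first: [0xB, 0xA, 0xF, 0x0]
    assert_eq!(limbs_base2k(0x0FAB, 4, 4), vec![0xB, 0xA, 0xF, 0x0]);
}
```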
@@ -62,10 +62,7 @@ unsafe extern "C" {
     pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_fft_precomp_get_buffer(
-        tables: *const REIM_FFT_PRECOMP,
-        buffer_index: u32,
-    ) -> *mut f64;
+    pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
 }
 unsafe extern "C" {
     pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
@@ -80,10 +77,7 @@ unsafe extern "C" {
     pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_ifft_precomp_get_buffer(
-        tables: *const REIM_IFFT_PRECOMP,
-        buffer_index: u32,
-    ) -> *mut f64;
+    pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
 }
 unsafe extern "C" {
     pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
@@ -92,120 +86,58 @@ unsafe extern "C" {
     pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_fftvec_mul(
-        tables: *const REIM_FFTVEC_MUL_PRECOMP,
-        r: *mut f64,
-        a: *const f64,
-        b: *const f64,
-    );
+    pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
 }
 unsafe extern "C" {
     pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_fftvec_addmul(
-        tables: *const REIM_FFTVEC_ADDMUL_PRECOMP,
-        r: *mut f64,
-        a: *const f64,
-        b: *const f64,
-    );
+    pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
 }
 unsafe extern "C" {
-    pub unsafe fn new_reim_from_znx32_precomp(
-        m: u32,
-        log2bound: u32,
-    ) -> *mut REIM_FROM_ZNX32_PRECOMP;
+    pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_from_znx32(
-        tables: *const REIM_FROM_ZNX32_PRECOMP,
-        r: *mut ::std::os::raw::c_void,
-        a: *const i32,
-    );
+    pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
 }
 unsafe extern "C" {
-    pub unsafe fn reim_from_znx64(
-        tables: *const REIM_FROM_ZNX64_PRECOMP,
-        r: *mut ::std::os::raw::c_void,
-        a: *const i64,
-    );
+    pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
 }
 unsafe extern "C" {
     pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_from_znx64_simple(
-        m: u32,
-        log2bound: u32,
-        r: *mut ::std::os::raw::c_void,
-        a: *const i64,
-    );
+    pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
 }
 unsafe extern "C" {
     pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_from_tnx32(
-        tables: *const REIM_FROM_TNX32_PRECOMP,
-        r: *mut ::std::os::raw::c_void,
-        a: *const i32,
-    );
+    pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
 }
 unsafe extern "C" {
-    pub unsafe fn new_reim_to_tnx32_precomp(
-        m: u32,
-        divisor: f64,
-        log2overhead: u32,
-    ) -> *mut REIM_TO_TNX32_PRECOMP;
+    pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_to_tnx32(
-        tables: *const REIM_TO_TNX32_PRECOMP,
-        r: *mut i32,
-        a: *const ::std::os::raw::c_void,
-    );
+    pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
 }
 unsafe extern "C" {
-    pub unsafe fn new_reim_to_tnx_precomp(
-        m: u32,
-        divisor: f64,
-        log2overhead: u32,
-    ) -> *mut REIM_TO_TNX_PRECOMP;
+    pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
 }
 unsafe extern "C" {
     pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
 }
 unsafe extern "C" {
-    pub unsafe fn reim_to_tnx_simple(
-        m: u32,
-        divisor: f64,
-        log2overhead: u32,
-        r: *mut f64,
-        a: *const f64,
-    );
+    pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
 }
 unsafe extern "C" {
-    pub unsafe fn new_reim_to_znx64_precomp(
-        m: u32,
-        divisor: f64,
-        log2bound: u32,
-    ) -> *mut REIM_TO_ZNX64_PRECOMP;
+    pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
 }
 unsafe extern "C" {
-    pub unsafe fn reim_to_znx64(
-        precomp: *const REIM_TO_ZNX64_PRECOMP,
-        r: *mut i64,
-        a: *const ::std::os::raw::c_void,
-    );
+    pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
 }
 unsafe extern "C" {
-    pub unsafe fn reim_to_znx64_simple(
-        m: u32,
-        divisor: f64,
-        log2bound: u32,
-        r: *mut i64,
-        a: *const ::std::os::raw::c_void,
-    );
+    pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
 }
 unsafe extern "C" {
     pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
@@ -230,22 +162,11 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub unsafe fn reim_from_znx32_simple(
-        m: u32,
-        log2bound: u32,
-        r: *mut ::std::os::raw::c_void,
-        x: *const i32,
-    );
+    pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
 }
 unsafe extern "C" {
     pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
 }
 unsafe extern "C" {
-    pub unsafe fn reim_to_tnx32_simple(
-        m: u32,
-        divisor: f64,
-        log2overhead: u32,
-        r: *mut i32,
-        x: *const ::std::os::raw::c_void,
-    );
+    pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
 }
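The `_simple` entry points above take the transform size `m` directly instead of a precomputed table. A hedged sketch of a forward transform using only functions declared in this header (buffer sizing and layout are assumptions; the benchmark at the top of this commit shows the precomp-based equivalent):

```rust
use base2k::ffi::reim::*;
use std::ffi::c_void;

// Sketch: load signed coefficients into a reim buffer, then FFT it in place.
unsafe fn fft_simple(m: u32, log_bound: u32, a: &[i64]) -> *mut f64 {
    // new_reim_fft_buffer allocates a buffer suitable for an m-point reim FFT.
    let buf: *mut f64 = new_reim_fft_buffer(m);
    reim_from_znx64_simple(m, log_bound, buf as *mut c_void, a.as_ptr());
    reim_fft_simple(m, buf as *mut c_void);
    buf
}
```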
@@ -44,14 +44,7 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub unsafe fn vec_znx_dft(
-        module: *const MODULE,
-        res: *mut VEC_ZNX_DFT,
-        res_size: u64,
-        a: *const i64,
-        a_size: u64,
-        a_sl: u64,
-    );
+    pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
 }
 unsafe extern "C" {
     pub unsafe fn vec_znx_idft(
@@ -37,13 +37,22 @@ unsafe extern "C" {
 }
 
 unsafe extern "C" {
-    pub unsafe fn vmp_apply_dft_tmp_bytes(
+    pub unsafe fn vmp_apply_dft_add(
         module: *const MODULE,
+        res: *mut VEC_ZNX_DFT,
         res_size: u64,
+        a: *const i64,
         a_size: u64,
+        a_sl: u64,
+        pmat: *const VMP_PMAT,
         nrows: u64,
         ncols: u64,
-    ) -> u64;
+        tmp_space: *mut u8,
+    );
+}
+
+unsafe extern "C" {
+    pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
 }
 
 unsafe extern "C" {
@@ -60,6 +69,20 @@ unsafe extern "C" {
     );
 }
 
+unsafe extern "C" {
+    pub unsafe fn vmp_apply_dft_to_dft_add(
+        module: *const MODULE,
+        res: *mut VEC_ZNX_DFT,
+        res_size: u64,
+        a_dft: *const VEC_ZNX_DFT,
+        a_size: u64,
+        pmat: *const VMP_PMAT,
+        nrows: u64,
+        ncols: u64,
+        tmp_space: *mut u8,
+    );
+}
+
 unsafe extern "C" {
     pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
         module: *const MODULE,
@@ -64,24 +64,11 @@ unsafe extern "C" {
     pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
 }
 unsafe extern "C" {
-    pub unsafe fn znx_normalize(
-        nn: u64,
-        base_k: u64,
-        out: *mut i64,
-        carry_out: *mut i64,
-        in_: *const i64,
-        carry_in: *const i64,
-    );
+    pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
 }
 
 unsafe extern "C" {
-    pub unsafe fn znx_small_single_product(
-        module: *const MODULE,
-        res: *mut i64,
-        a: *const i64,
-        b: *const i64,
-        tmp: *mut u8,
-    );
+    pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
 }
 
 unsafe extern "C" {
@@ -1,11 +1,5 @@
 pub mod encoding;
-#[allow(
-    non_camel_case_types,
-    non_snake_case,
-    non_upper_case_globals,
-    dead_code,
-    improper_ctypes
-)]
+#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
 // Other modules and exports
 pub mod ffi;
 pub mod infos;
@@ -42,7 +36,10 @@ pub fn is_aligned<T>(ptr: *const T) -> bool {
 }
 
 pub fn assert_alignement<T>(ptr: *const T) {
-    assert!(is_aligned(ptr), "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]")
+    assert!(
+        is_aligned(ptr),
+        "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]"
+    )
 }
 
 pub fn cast<T, V>(data: &[T]) -> &[V] {
@@ -57,7 +54,7 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
     unsafe { std::slice::from_raw_parts_mut(ptr, len) }
 }
 
-use std::alloc::{alloc, Layout};
+use std::alloc::{Layout, alloc};
 use std::ptr;
 
 /// Allocates a block of bytes with a custom alignement.
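The reformatted `assert_alignement` above is the check that every `from_bytes`/`from_bytes_borrow` constructor in this commit relies on. A minimal sketch of the intended allocation pattern (the size is a placeholder):

```rust
use base2k::{alloc_aligned, assert_alignement, is_aligned};

fn main() {
    // Scratch handed to the borrow-style constructors must come from alloc_aligned,
    // otherwise assert_alignement panics with the message shown in the diff.
    let tmp_bytes: Vec<u8> = alloc_aligned(1 << 16);
    assert!(is_aligned(tmp_bytes.as_ptr()));
    assert_alignement(tmp_bytes.as_ptr());
}
```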
@@ -1,5 +1,5 @@
-use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE};
 use crate::GALOISGENERATOR;
+use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
 
 #[derive(Copy, Clone)]
 #[repr(u8)]
@@ -51,27 +51,21 @@ impl Module {
         (self.n() << 1) as _
     }
 
-    // GALOISGENERATOR^|gen| * sign(gen)
+    // Returns GALOISGENERATOR^|gen| * sign(gen)
     pub fn galois_element(&self, gen: i64) -> i64 {
         if gen == 0 {
             return 1;
         }
-
-        let mut gal_el: u64 = 1;
-        let mut gen_1_pow: u64 = GALOISGENERATOR;
-        let mut e: usize = gen.abs() as usize;
-        while e > 0 {
-            if e & 1 == 1 {
-                gal_el = gal_el.wrapping_mul(gen_1_pow);
-            }
-
-            gen_1_pow = gen_1_pow.wrapping_mul(gen_1_pow);
-            e >>= 1;
-        }
-
-        gal_el &= self.cyclotomic_order() - 1;
-
-        (gal_el as i64) * gen.signum()
+        ((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum()
+    }
+
+    // Returns gen^-1
+    pub fn galois_element_inv(&self, gen: i64) -> i64 {
+        if gen == 0 {
+            panic!("cannot invert 0")
+        }
+        ((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
+            * gen.signum()
     }
 
     pub fn free(self) {
@@ -79,3 +73,17 @@ impl Module {
         drop(self);
     }
 }
+
+fn mod_exp_u64(x: u64, e: usize) -> u64 {
+    let mut y: u64 = 1;
+    let mut x_pow: u64 = x;
+    let mut exp = e;
+    while exp > 0 {
+        if exp & 1 == 1 {
+            y = y.wrapping_mul(x_pow);
+        }
+        x_pow = x_pow.wrapping_mul(x_pow);
+        exp >>= 1;
+    }
+    y
+}
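The new `mod_exp_u64` helper is plain square-and-multiply, and `galois_element` masks the result with `cyclotomic_order() - 1`, which works because the cyclotomic order 2N is a power of two. A standalone worked check of that logic (the generator value 5 is illustrative; the crate uses its `GALOISGENERATOR` constant):

```rust
// Square-and-multiply, mirroring the mod_exp_u64 added in this commit.
fn mod_exp_u64(x: u64, e: usize) -> u64 {
    let (mut y, mut x_pow, mut exp) = (1u64, x, e);
    while exp > 0 {
        if exp & 1 == 1 {
            y = y.wrapping_mul(x_pow);
        }
        x_pow = x_pow.wrapping_mul(x_pow);
        exp >>= 1;
    }
    y
}

fn main() {
    // For ring degree n = 1024 the cyclotomic order is 2n = 2048, so reduction is a mask.
    let two_n: u64 = 2048;
    let (g, k) = (5u64, 3usize);
    assert_eq!(mod_exp_u64(g, k) & (two_n - 1), 125); // 5^3 = 125 < 2048
}
```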
@@ -18,15 +18,7 @@ pub trait Sampling {
     );
 
     /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
-    fn add_normal(
-        &self,
-        log_base2k: usize,
-        a: &mut VecZnx,
-        log_k: usize,
-        source: &mut Source,
-        sigma: f64,
-        bound: f64,
-    );
+    fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
 }
 
 impl Sampling for Module {
@@ -63,7 +55,7 @@ impl Sampling for Module {
                 while dist_f64.abs() > bound {
                     dist_f64 = dist.sample(source)
                 }
-                *a += (dist_f64.round() as i64) << log_base2k_rem
+                *a += (dist_f64.round() as i64) << log_base2k_rem;
             });
         } else {
             a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
@@ -76,15 +68,7 @@ impl Sampling for Module {
         }
     }
 
-    fn add_normal(
-        &self,
-        log_base2k: usize,
-        a: &mut VecZnx,
-        log_k: usize,
-        source: &mut Source,
-        sigma: f64,
-        bound: f64,
-    ) {
+    fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
         self.add_dist_f64(
             log_base2k,
             a,
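The flattened `add_normal` signature keeps the same parameters: base-2^k decomposition, target precision `log_k`, a randomness `Source`, the standard deviation, and a hard bound. A hedged caller-side sketch (the numeric values are placeholders):

```rust
use base2k::{Module, Sampling, VecZnx};
use sampling::source::Source;

// Sketch: add bounded discrete Gaussian noise scaled by 2^{-log_k} to `a`.
fn add_noise(module: &Module, a: &mut VecZnx, source: &mut Source) {
    let log_base2k: usize = 17;
    let log_k: usize = 2 * log_base2k;
    module.add_normal(log_base2k, a, log_k, source, 3.2, 19.0);
}
```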
@@ -1,7 +1,7 @@
 use crate::{Encoding, Infos, VecZnx};
+use rug::Float;
 use rug::float::Round;
 use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
-use rug::Float;
 
 impl VecZnx {
     pub fn std(&self, log_base2k: usize) -> f64 {
@@ -1,8 +1,8 @@
-use crate::ffi::svp;
+use crate::ffi::svp::{self, svp_ppol_t};
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
-use crate::{assert_alignement, Module, VecZnx, VecZnxDft};
+use crate::{BACKEND, Module, VecZnx, VecZnxDft, assert_alignement};
 
-use crate::{alloc_aligned, cast_mut, Infos};
+use crate::{Infos, alloc_aligned, cast_mut};
 use rand::seq::SliceRandom;
 use rand_core::RngCore;
 use rand_distr::{Distribution, WeightedIndex};
@@ -35,15 +35,15 @@ impl Scalar {
         self.n
     }
 
-    pub fn buffer_size(n: usize) -> usize {
-        n
+    pub fn bytes_of(n: usize) -> usize {
+        n * std::mem::size_of::<i64>()
     }
 
-    pub fn from_buffer(&mut self, n: usize, bytes: &mut [u8]) -> Self {
-        let size: usize = Self::buffer_size(n);
+    pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
+        let size: usize = Self::bytes_of(n);
         debug_assert!(
             bytes.len() == size,
-            "invalid buffer: bytes.len()={} < self.buffer_size(n={})={}",
+            "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
             bytes.len(),
             n,
             size
@@ -63,11 +63,37 @@ impl Scalar {
         }
     }
 
+    pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
+        let size: usize = Self::bytes_of(n);
+        debug_assert!(
+            bytes.len() == size,
+            "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
+            bytes.len(),
+            n,
+            size
+        );
+        #[cfg(debug_assertions)]
+        {
+            assert_alignement(bytes.as_ptr())
+        }
+        let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
+        let ptr: *mut i64 = bytes_i64.as_mut_ptr();
+        Self {
+            n: n,
+            data: Vec::new(),
+            ptr: ptr,
+        }
+    }
+
     pub fn as_ptr(&self) -> *const i64 {
         self.ptr
     }
 
     pub fn raw(&self) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
+    }
+
+    pub fn raw_mut(&self) -> &mut [i64] {
         unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
     }
 
@@ -87,26 +113,89 @@ impl Scalar {
             .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
         self.data.shuffle(source);
     }
+
+    pub fn as_vec_znx(&self) -> VecZnx {
+        VecZnx {
+            n: self.n,
+            cols: 1,
+            data: Vec::new(),
+            ptr: self.ptr,
+        }
+    }
 }
 
-pub struct SvpPPol(pub *mut svp::svp_ppol_t, pub usize);
+pub trait ScalarOps {
+    fn bytes_of_scalar(&self) -> usize;
+    fn new_scalar(&self) -> Scalar;
+    fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
+    fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
+}
+impl ScalarOps for Module {
+    fn bytes_of_scalar(&self) -> usize {
+        Scalar::bytes_of(self.n())
+    }
+    fn new_scalar(&self) -> Scalar {
+        Scalar::new(self.n())
+    }
+    fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
+        Scalar::from_bytes(self.n(), bytes)
+    }
+    fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
+        Scalar::from_bytes_borrow(self.n(), tmp_bytes)
+    }
+}
+
+pub struct SvpPPol {
+    pub n: usize,
+    pub data: Vec<u8>,
+    pub ptr: *mut u8,
+    pub backend: BACKEND,
+}
 
 /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
 /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
-/// The backend array of an [SvpPPol] is allocated in C and must be freed manually.
 impl SvpPPol {
-    /// Returns the ring degree of the [SvpPPol].
-    pub fn n(&self) -> usize {
-        self.1
+    pub fn new(module: &Module) -> Self {
+        module.new_svp_ppol()
     }
 
-    pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol {
+    /// Returns the ring degree of the [SvpPPol].
+    pub fn n(&self) -> usize {
+        self.n
+    }
+
+    pub fn bytes_of(module: &Module) -> usize {
+        module.bytes_of_svp_ppol()
+    }
+
+    pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol {
         #[cfg(debug_assertions)]
         {
-            assert_alignement(bytes.as_ptr())
+            assert_alignement(bytes.as_ptr());
+            assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
+        }
+        unsafe {
+            Self {
+                n: module.n(),
+                data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
+                ptr: bytes.as_mut_ptr(),
+                backend: module.backend(),
+            }
+        }
+    }
+
+    pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol {
+        #[cfg(debug_assertions)]
+        {
+            assert_alignement(tmp_bytes.as_ptr());
+            assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
+        }
+        Self {
+            n: module.n(),
+            data: Vec::new(),
+            ptr: tmp_bytes.as_mut_ptr(),
+            backend: module.backend(),
         }
-        debug_assert!(bytes.len() << 3 >= size);
-        SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size)
     }
 
     /// Returns the number of cols of the [SvpPPol], which is always 1.
@@ -120,45 +209,64 @@ pub trait SvpPPolOps {
     fn new_svp_ppol(&self) -> SvpPPol;
 
     /// Returns the minimum number of bytes necessary to allocate
-    /// a new [SvpPPol] through [SvpPPol::from_bytes].
+    /// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
    fn bytes_of_svp_ppol(&self) -> usize;
 
+    /// Allocates a new [SvpPPol] from an array of bytes.
+    /// The array of bytes is owned by the [SvpPPol].
+    /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
+    fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol;
+
+    /// Allocates a new [SvpPPol] from an array of bytes.
+    /// The array of bytes is borrowed by the [SvpPPol].
+    /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
+    fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol;
+
     /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
     fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
 
     /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
     /// the [VecZnxDft] is multiplied with [SvpPPol].
-    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize);
+    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
 }
 
 impl SvpPPolOps for Module {
     fn new_svp_ppol(&self) -> SvpPPol {
-        unsafe { SvpPPol(svp::new_svp_ppol(self.ptr), self.n()) }
+        let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
+        let ptr: *mut u8 = data.as_mut_ptr();
+        SvpPPol {
+            data: data,
+            ptr: ptr,
+            n: self.n(),
+            backend: self.backend(),
+        }
     }
 
     fn bytes_of_svp_ppol(&self) -> usize {
         unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
     }
 
-    fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
-        unsafe { svp::svp_prepare(self.ptr, svp_ppol.0, a.as_ptr()) }
+    fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol {
+        SvpPPol::from_bytes(self, bytes)
     }
 
-    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize) {
-        debug_assert!(
-            c.cols() >= b_cols,
-            "invalid c_vector: c_vector.cols()={} < b.cols()={}",
-            c.cols(),
-            b_cols
-        );
+    fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol {
+        SvpPPol::from_bytes_borrow(self, tmp_bytes)
+    }
+
+    fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
+        unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
+    }
+
+    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
         unsafe {
             svp::svp_apply_dft(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                b_cols as u64,
-                a.0,
+                c.cols() as u64,
+                a.ptr as *const svp_ppol_t,
                 b.as_ptr(),
-                b_cols as u64,
+                b.cols() as u64,
                 b.n() as u64,
             )
         }
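The new `ScalarOps` trait and the Rust-allocated `SvpPPol` replace the old `SvpPPol` tuple struct (whose backing array lived in C and had to be freed manually). A short usage sketch, assuming `ScalarOps` is re-exported at the crate root like the other `*Ops` traits:

```rust
use base2k::{Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps};

// Sketch: allocate a secret Scalar and prepare it for svp_apply_dft.
fn make_secret(module: &Module) -> (Scalar, SvpPPol) {
    let s: Scalar = module.new_scalar();             // was Scalar::new(module.n())
    let mut s_ppol: SvpPPol = module.new_svp_ppol(); // now backed by a Rust Vec<u8>, no manual free
    module.svp_prepare(&mut s_ppol, &s);
    (s, s_ppol)
}
```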
@@ -1,8 +1,8 @@
 use crate::cast_mut;
 use crate::ffi::vec_znx;
 use crate::ffi::znx;
-use crate::{alloc_aligned, assert_alignement};
 use crate::{Infos, Module};
+use crate::{alloc_aligned, assert_alignement};
 use itertools::izip;
 use std::cmp::min;
 
@@ -12,16 +12,16 @@ use std::cmp::min;
 #[derive(Clone)]
 pub struct VecZnx {
     /// Polynomial degree.
-    n: usize,
+    pub n: usize,
 
     /// Number of columns.
-    cols: usize,
+    pub cols: usize,
 
     /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
-    data: Vec<i64>,
+    pub data: Vec<i64>,
 
     /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
-    ptr: *mut i64,
+    pub ptr: *mut i64,
 }
 
 pub trait VecZnxVec {
@@ -347,8 +347,11 @@ pub trait VecZnxOps {
     /// c <- a - b.
     fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
 
+    /// b <- a - b.
+    fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
+
     /// b <- b - a.
-    fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
 
     /// b <- -a.
     fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
@@ -363,10 +366,10 @@ pub trait VecZnxOps {
     fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
 
     /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize);
+    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
 
     /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize);
+    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
 
     /// Splits b into subrings and copies them them into a.
     ///
@@ -452,8 +455,8 @@ impl VecZnxOps for Module {
         }
     }
 
-    // b <- a + b
-    fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
+    // b <- a - b
+    fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
         unsafe {
             vec_znx::vec_znx_sub(
                 self.ptr,
@@ -470,6 +473,24 @@ impl VecZnxOps for Module {
         }
     }
 
+    // b <- b - a
+    fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
+        unsafe {
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                b.as_mut_ptr(),
+                b.cols() as u64,
+                b.n() as u64,
+                b.as_ptr(),
+                b.cols() as u64,
+                b.n() as u64,
+                a.as_ptr(),
+                a.cols() as u64,
+                a.n() as u64,
+            )
+        }
+    }
+
     fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
         unsafe {
             vec_znx::vec_znx_negate(
@@ -540,10 +561,9 @@ impl VecZnxOps for Module {
     /// # Panics
     ///
     /// The method will panic if the argument `a` is greater than `a.cols()`.
-    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) {
+    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
         debug_assert_eq!(a.n(), self.n());
         debug_assert_eq!(b.n(), self.n());
-        debug_assert!(a.cols() >= a_cols);
         unsafe {
             vec_znx::vec_znx_automorphism(
                 self.ptr,
@@ -552,7 +572,7 @@ impl VecZnxOps for Module {
                 b.cols() as u64,
                 b.n() as u64,
                 a.as_ptr(),
-                a_cols as u64,
+                a.cols() as u64,
                 a.n() as u64,
             );
         }
@@ -569,9 +589,8 @@ impl VecZnxOps for Module {
     /// # Panics
     ///
     /// The method will panic if the argument `cols` is greater than `self.cols()`.
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) {
+    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
         debug_assert_eq!(a.n(), self.n());
-        debug_assert!(a.cols() >= a_cols);
         unsafe {
             vec_znx::vec_znx_automorphism(
                 self.ptr,
@@ -580,7 +599,7 @@ impl VecZnxOps for Module {
                 a.cols() as u64,
                 a.n() as u64,
                 a.as_ptr(),
-                a_cols as u64,
+                a.cols() as u64,
                 a.n() as u64,
             );
         }
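The caller-visible changes in this file are the two explicitly named in-place subtractions and the automorphisms that now read the column count from their operands. A minimal sketch (5 as the Galois element is illustrative):

```rust
use base2k::{Module, VecZnx, VecZnxOps};

// Sketch of the renamed in-place subtractions and the trimmed automorphism call.
fn sub_and_twist(module: &Module, a: &VecZnx, b: &mut VecZnx, out: &mut VecZnx) {
    // b <- a - b
    module.vec_znx_sub_ab_inplace(b, a);
    // b <- b - a
    module.vec_znx_sub_ba_inplace(b, a);
    // out <- phi_5(b); the column count is now taken from b.cols().
    module.vec_znx_automorphism(5, out, b);
}
```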
@@ -1,5 +1,5 @@
 use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
-use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, BACKEND};
+use crate::{BACKEND, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement};
 
 pub struct VecZnxBig {
     pub data: Vec<u8>,
@@ -16,6 +16,7 @@ impl VecZnxBig {
     pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self {
         #[cfg(debug_assertions)]
         {
+            assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols));
             assert_alignement(bytes.as_ptr())
         };
         unsafe {
@@ -54,14 +55,6 @@ impl VecZnxBig {
         }
     }
 
-    pub fn n(&self) -> usize {
-        self.n
-    }
-
-    pub fn cols(&self) -> usize {
-        self.cols
-    }
-
     pub fn backend(&self) -> BACKEND {
         self.backend
     }
@@ -77,12 +70,36 @@ impl VecZnxBig {
     }
 }
 
+impl Infos for VecZnxBig {
+    /// Returns the base 2 logarithm of the [VecZnx] degree.
+    fn log_n(&self) -> usize {
+        (usize::BITS - (self.n - 1).leading_zeros()) as _
+    }
+
+    /// Returns the [VecZnx] degree.
+    fn n(&self) -> usize {
+        self.n
+    }
+
+    /// Returns the number of cols of the [VecZnx].
+    fn cols(&self) -> usize {
+        self.cols
+    }
+
+    /// Returns the number of rows of the [VecZnx].
+    fn rows(&self) -> usize {
+        1
+    }
+}
+
 pub trait VecZnxBigOps {
     /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
     fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig;
 
     /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
     ///
+    /// Behavior: takes ownership of the backing array.
+    ///
     /// # Arguments
     ///
     /// * `cols`: the number of cols of the [VecZnxBig].
@@ -92,6 +109,19 @@ pub trait VecZnxBigOps {
     /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
     fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig;
 
+    /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
+    ///
+    /// Behavior: the backing array is only borrowed.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of cols of the [VecZnxBig].
+    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    ///
+    /// # Panics
+    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
+    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
+
     /// Returns the minimum number of bytes necessary to allocate
     /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
     fn bytes_of_vec_znx_big(&self, cols: usize) -> usize;
@@ -111,13 +141,7 @@ pub trait VecZnxBigOps {
     fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
 
     /// b <- normalize(a)
-    fn vec_znx_big_normalize(
-        &self,
-        log_base2k: usize,
-        b: &mut VecZnx,
-        a: &VecZnxBig,
-        tmp_bytes: &mut [u8],
-    );
+    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]);
 
     fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
 
@@ -151,19 +175,13 @@ impl VecZnxBigOps for Module {
     }
 
     fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
-        debug_assert!(
-            bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols),
-            "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
-            bytes.len(),
-            <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols)
-        );
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(bytes.as_ptr())
-        }
         VecZnxBig::from_bytes(self, cols, bytes)
     }
 
+    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig {
+        VecZnxBig::from_bytes_borrow(self, cols, tmp_bytes)
+    }
+
     fn bytes_of_vec_znx_big(&self, cols: usize) -> usize {
         unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize }
     }
@@ -232,13 +250,7 @@ impl VecZnxBigOps for Module {
         unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
     }
 
-    fn vec_znx_big_normalize(
-        &self,
-        log_base2k: usize,
-        b: &mut VecZnx,
-        a: &VecZnxBig,
-        tmp_bytes: &mut [u8],
-    ) {
+    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) {
         debug_assert!(
             tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
             "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
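The `_borrow` constructor added for `VecZnxBig` lets temporaries live in caller-provided scratch without transferring ownership of the bytes. A hedged sketch of the intended pattern:

```rust
use base2k::{Module, VecZnxBig, VecZnxBigOps, alloc_aligned};

// Sketch: back a temporary VecZnxBig with reusable, aligned scratch.
fn with_scratch_big(module: &Module, cols: usize) {
    let mut tmp_bytes: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_big(cols));
    // Borrowed: `tmp_bytes` keeps ownership, so the same scratch can back the next temporary.
    let _buf_big: VecZnxBig = module.new_vec_znx_big_from_bytes_borrow(cols, &mut tmp_bytes);
}
```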
@@ -1,8 +1,8 @@
|
|||||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||||
use crate::ffi::vec_znx_dft;
|
use crate::ffi::vec_znx_dft;
|
||||||
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
||||||
use crate::{alloc_aligned, VecZnx, DEFAULTALIGN};
|
use crate::{BACKEND, Infos, Module, VecZnxBig, assert_alignement};
|
||||||
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND};
|
use crate::{DEFAULTALIGN, VecZnx, alloc_aligned};
|
||||||
|
|
||||||
pub struct VecZnxDft {
|
pub struct VecZnxDft {
|
||||||
pub data: Vec<u8>,
|
pub data: Vec<u8>,
|
||||||
@@ -61,14 +61,6 @@ impl VecZnxDft {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn n(&self) -> usize {
|
|
||||||
self.n
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn cols(&self) -> usize {
|
|
||||||
self.cols
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn backend(&self) -> BACKEND {
|
pub fn backend(&self) -> BACKEND {
|
||||||
self.backend
|
self.backend
|
||||||
}
|
}
|
||||||
@@ -102,12 +94,36 @@ impl VecZnxDft {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Infos for VecZnxDft {
|
||||||
|
/// Returns the base 2 logarithm of the [VecZnx] degree.
|
||||||
|
fn log_n(&self) -> usize {
|
||||||
|
(usize::BITS - (self.n - 1).leading_zeros()) as _
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the [VecZnx] degree.
|
||||||
|
fn n(&self) -> usize {
|
||||||
|
self.n
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of cols of the [VecZnx].
|
||||||
|
fn cols(&self) -> usize {
|
||||||
|
self.cols
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of rows of the [VecZnx].
|
||||||
|
fn rows(&self) -> usize {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait VecZnxDftOps {
|
pub trait VecZnxDftOps {
|
||||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||||
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft;
|
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft;
|
||||||
|
|
||||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||||
///
|
///
|
||||||
|
/// Behavior: takes ownership of the backing array.
|
||||||
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||||
@@ -117,6 +133,19 @@ pub trait VecZnxDftOps {
|
|||||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
|
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
|
||||||
|
|
||||||
|
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||||
|
///
|
||||||
|
/// Behavior: the backing array is only borrowed.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||||
|
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||||
|
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
|
||||||
|
|
||||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -133,28 +162,15 @@ pub trait VecZnxDftOps {
    fn vec_znx_idft_tmp_bytes(&self) -> usize;

    /// b <- IDFT(a), uses a as scratch space.
    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);

    fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]);
    fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);

    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);

    fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, b_cols: usize, a: &VecZnxDft, a_cols: usize);
    fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);

    fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]);
    fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);

    fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
}
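The `*_cols` arguments disappear from these signatures because the column counts are now read from the operands themselves. A minimal sketch of the resulting call pattern, assuming a `Module` built as in the other examples of this commit and a `new_vec_znx_big` allocator on `VecZnxBigOps` (that helper name is an assumption, not confirmed by this diff):

use base2k::{Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned};

fn dft_roundtrip(module: &Module, a: &VecZnx) -> VecZnxBig {
    // Column counts are taken from `a` itself; no explicit a_cols argument anymore.
    let mut a_dft: VecZnxDft = module.new_vec_znx_dft(a.cols());
    let mut a_big: VecZnxBig = module.new_vec_znx_big(a.cols()); // assumed allocator
    let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_idft_tmp_bytes());
    module.vec_znx_dft(&mut a_dft, a); // forward DFT
    module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); // inverse DFT
    a_big
}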
@@ -173,37 +189,25 @@ impl VecZnxDftOps for Module {
    }

    fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
        debug_assert!(
            tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols),
            "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
            tmp_bytes.len(),
            Self::bytes_of_vec_znx_dft(self, cols)
        );
        #[cfg(debug_assertions)]
        {
            assert_alignement(tmp_bytes.as_ptr())
        }
        VecZnxDft::from_bytes(self, cols, tmp_bytes)
    }

    fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
        VecZnxDft::from_bytes_borrow(self, cols, tmp_bytes)
    }

    fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
        unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
    }

    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
        debug_assert!(
            b.cols() >= a_cols,
            "invalid c_vector: b_vector.cols()={} < a_cols={}",
            b.cols(),
            a_cols
        );
        unsafe {
            vec_znx_dft::vec_znx_idft_tmp_a(
                self.ptr,
                b.ptr as *mut vec_znx_big_t,
                b.cols() as u64,
                a.ptr as *mut vec_znx_dft_t,
                a_cols as u64,
                a.cols() as u64,
            )
        }
    }
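A short illustration of the difference between the owning and borrowing constructors implemented above; this is a sketch under the assumptions stated in the trait documentation, not a prescribed usage:

use base2k::{Module, VecZnxDft, VecZnxDftOps, alloc_aligned};

fn dft_on_caller_scratch(module: &Module, cols: usize) {
    let mut scratch: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(cols));
    // Borrowing constructor: the caller keeps ownership of `scratch` and can reuse it
    // once the returned VecZnxDft is dropped.
    let _a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, &mut scratch);
}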
@@ -216,41 +220,23 @@ impl VecZnxDftOps for Module {
|
|||||||
///
|
///
|
||||||
/// # Panics
|
/// # Panics
|
||||||
/// If b.cols < a_cols
|
/// If b.cols < a_cols
|
||||||
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) {
|
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
|
||||||
debug_assert!(
|
|
||||||
b.cols() >= a_cols,
|
|
||||||
"invalid a_cols: b.cols()={} < a_cols={}",
|
|
||||||
b.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
unsafe {
|
unsafe {
|
||||||
vec_znx_dft::vec_znx_dft(
|
vec_znx_dft::vec_znx_dft(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
b.ptr as *mut vec_znx_dft_t,
|
b.ptr as *mut vec_znx_dft_t,
|
||||||
b.cols() as u64,
|
b.cols() as u64,
|
||||||
a.as_ptr(),
|
a.as_ptr(),
|
||||||
a_cols as u64,
|
a.cols() as u64,
|
||||||
a.n() as u64,
|
a.n() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||||
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) {
|
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert!(
|
|
||||||
b.cols() >= a_cols,
|
|
||||||
"invalid c_vector: b.cols()={} < a_cols={}",
|
|
||||||
b.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
a.cols() >= a_cols,
|
|
||||||
"invalid c_vector: a.cols()={} < a_cols={}",
|
|
||||||
a.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
assert!(
|
assert!(
|
||||||
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
|
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
|
||||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
||||||
@@ -263,65 +249,31 @@ impl VecZnxDftOps for Module {
|
|||||||
vec_znx_dft::vec_znx_idft(
|
vec_znx_dft::vec_znx_idft(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
b.ptr as *mut vec_znx_big_t,
|
b.ptr as *mut vec_znx_big_t,
|
||||||
a.cols() as u64,
|
b.cols() as u64,
|
||||||
a.ptr as *const vec_znx_dft_t,
|
a.ptr as *const vec_znx_dft_t,
|
||||||
a_cols as u64,
|
a.cols() as u64,
|
||||||
tmp_bytes.as_mut_ptr(),
|
tmp_bytes.as_mut_ptr(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_dft_automorphism(
|
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
|
||||||
&self,
|
|
||||||
k: i64,
|
|
||||||
b: &mut VecZnxDft,
|
|
||||||
b_cols: usize,
|
|
||||||
a: &VecZnxDft,
|
|
||||||
a_cols: usize,
|
|
||||||
) {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert!(
|
|
||||||
b.cols() >= a_cols,
|
|
||||||
"invalid c_vector: b.cols()={} < a_cols={}",
|
|
||||||
b.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
a.cols() >= a_cols,
|
|
||||||
"invalid c_vector: a.cols()={} < a_cols={}",
|
|
||||||
a.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
}
|
|
||||||
unsafe {
|
unsafe {
|
||||||
vec_znx_dft::vec_znx_dft_automorphism(
|
vec_znx_dft::vec_znx_dft_automorphism(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
k,
|
k,
|
||||||
b.ptr as *mut vec_znx_dft_t,
|
b.ptr as *mut vec_znx_dft_t,
|
||||||
b_cols as u64,
|
b.cols() as u64,
|
||||||
a.ptr as *const vec_znx_dft_t,
|
a.ptr as *const vec_znx_dft_t,
|
||||||
a_cols as u64,
|
a.cols() as u64,
|
||||||
[0u8; 0].as_mut_ptr(),
|
[0u8; 0].as_mut_ptr(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_dft_automorphism_inplace(
|
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
|
||||||
&self,
|
|
||||||
k: i64,
|
|
||||||
a: &mut VecZnxDft,
|
|
||||||
a_cols: usize,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert!(
|
|
||||||
a.cols() >= a_cols,
|
|
||||||
"invalid c_vector: a.cols()={} < a_cols={}",
|
|
||||||
a.cols(),
|
|
||||||
a_cols
|
|
||||||
);
|
|
||||||
assert!(
|
assert!(
|
||||||
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
|
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
|
||||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
|
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
|
||||||
@@ -335,9 +287,9 @@ impl VecZnxDftOps for Module {
|
|||||||
self.ptr,
|
self.ptr,
|
||||||
k,
|
k,
|
||||||
a.ptr as *mut vec_znx_dft_t,
|
a.ptr as *mut vec_znx_dft_t,
|
||||||
a_cols as u64,
|
a.cols() as u64,
|
||||||
a.ptr as *const vec_znx_dft_t,
|
a.ptr as *const vec_znx_dft_t,
|
||||||
a_cols as u64,
|
a.cols() as u64,
|
||||||
tmp_bytes.as_mut_ptr(),
|
tmp_bytes.as_mut_ptr(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -355,11 +307,9 @@ impl VecZnxDftOps for Module {

#[cfg(test)]
mod tests {
    use crate::{
        alloc_aligned, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND,
    };
    use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
    use itertools::izip;
    use sampling::source::{new_seed, Source};
    use sampling::source::{Source, new_seed};

    #[test]
    fn test_automorphism_dft() {
@@ -379,16 +329,16 @@ mod tests {
        let p: i64 = -5;

        // a_dft <- DFT(a)
        module.vec_znx_dft(&mut a_dft, &a, cols);
        module.vec_znx_dft(&mut a_dft, &a);

        // a_dft <- AUTO(a_dft)
        module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes);
        module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);

        // a <- AUTO(a)
        module.vec_znx_automorphism_inplace(p, &mut a, cols);
        module.vec_znx_automorphism_inplace(p, &mut a);

        // b_dft <- DFT(AUTO(a))
        module.vec_znx_dft(&mut b_dft, &a, cols);
        module.vec_znx_dft(&mut b_dft, &a);

        let a_f64: &[f64] = a_dft.raw(&module);
        let b_f64: &[f64] = b_dft.raw(&module);
|||||||
@@ -1,9 +1,7 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND};
use crate::{BACKEND, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};

/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -26,6 +24,7 @@ pub struct VmpPMat {
    /// The ring degree of each [VecZnxDft].
    n: usize,

    #[warn(dead_code)]
    backend: BACKEND,
}

@@ -99,8 +98,7 @@ impl VmpPMat {

        if self.n < 8 {
            res.copy_from_slice(
                &self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)],
            );
        } else {
            (0..self.n >> 3).for_each(|blk| {
@@ -119,10 +117,7 @@ impl VmpPMat {
                if col == (ncols - 1) && (ncols & 1 == 1) {
                    &self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
                } else {
                    &self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
                }
            }
        }
@@ -219,13 +214,7 @@ pub trait VmpPMatOps {
    /// * `a_cols`: number of cols of the input [VecZnx].
    /// * `rows`: number of rows of the input [VmpPMat].
    /// * `cols`: number of cols of the input [VmpPMat].
    fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;

    /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
    ///
@@ -253,6 +242,32 @@ pub trait VmpPMatOps {
    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
    fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);

    /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds the result on the receiver.
    ///
    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
    /// and each vector a [VecZnxDft] (row) of the [VmpPMat].
    ///
    /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
    /// `j` cols, the output is a [VecZnx] of `j` cols.
    ///
    /// If there is a mismatch between the dimensions the largest valid ones are used.
    ///
    /// ```text
    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
    ///             |h i j|
    ///             |k l m|
    /// ```
    /// where each element is a [VecZnxDft].
    ///
    /// # Arguments
    ///
    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
    /// * `a`: the left operand [VecZnx] of the vector matrix product.
    /// * `b`: the right operand [VmpPMat] of the vector matrix product.
    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
    fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);

    /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
    ///
    /// # Arguments
@@ -261,13 +276,7 @@ pub trait VmpPMatOps {
    /// * `a_cols`: number of cols of the input [VecZnxDft].
    /// * `rows`: number of rows of the input [VmpPMat].
    /// * `cols`: number of cols of the input [VmpPMat].
    fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;

    /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
    /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -296,6 +305,33 @@ pub trait VmpPMatOps {
    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
    fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);

    /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwriting it.
    /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
    ///
    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
    /// and each vector a [VecZnxDft] (row) of the [VmpPMat].
    ///
    /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
    /// `j` cols, the output is a [VecZnx] of `j` cols.
    ///
    /// If there is a mismatch between the dimensions the largest valid ones are used.
    ///
    /// ```text
    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
    ///             |h i j|
    ///             |k l m|
    /// ```
    /// where each element is a [VecZnxDft].
    ///
    /// # Arguments
    ///
    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
    /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
    /// * `b`: the right operand [VmpPMat] of the vector matrix product.
    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
    fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);

    /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
    /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
    ///
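The two `_add` variants declared above accumulate the product onto the receiver instead of overwriting it. A hedged sketch of the call pattern, assuming the operands are allocated as in the tests of this commit (names and shapes are illustrative only):

use base2k::{Infos, Module, VecZnx, VecZnxDft, VmpPMat, VmpPMatOps, alloc_aligned};

fn sum_of_products(module: &Module, c: &mut VecZnxDft, a0: &VecZnx, a1: &VecZnx, m: &VmpPMat) {
    // a0 and a1 are assumed to have the same number of cols, so one scratch buffer fits both calls.
    let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c.cols(), a0.cols(), m.rows(), m.cols()));
    module.vmp_apply_dft(c, a0, m, &mut buf); // c <- a0 x m
    module.vmp_apply_dft_add(c, a1, m, &mut buf); // c <- c + a1 x m
}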
@@ -461,13 +497,7 @@ impl VmpPMatOps for Module {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft_tmp_bytes(
|
fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||||
&self,
|
|
||||||
res_cols: usize,
|
|
||||||
a_cols: usize,
|
|
||||||
gct_rows: usize,
|
|
||||||
gct_cols: usize,
|
|
||||||
) -> usize {
|
|
||||||
unsafe {
|
unsafe {
|
||||||
vmp::vmp_apply_dft_tmp_bytes(
|
vmp::vmp_apply_dft_tmp_bytes(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
@@ -480,9 +510,7 @@ impl VmpPMatOps for Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||||
debug_assert!(
|
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||||
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_alignement(tmp_bytes.as_ptr());
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
@@ -503,13 +531,29 @@ impl VmpPMatOps for Module {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||||
&self,
|
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||||
res_cols: usize,
|
#[cfg(debug_assertions)]
|
||||||
a_cols: usize,
|
{
|
||||||
gct_rows: usize,
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
gct_cols: usize,
|
}
|
||||||
) -> usize {
|
unsafe {
|
||||||
|
vmp::vmp_apply_dft_add(
|
||||||
|
self.ptr,
|
||||||
|
c.ptr as *mut vec_znx_dft_t,
|
||||||
|
c.cols() as u64,
|
||||||
|
a.as_ptr(),
|
||||||
|
a.cols() as u64,
|
||||||
|
a.n() as u64,
|
||||||
|
b.as_ptr() as *const vmp_pmat_t,
|
||||||
|
b.rows() as u64,
|
||||||
|
b.cols() as u64,
|
||||||
|
tmp_bytes.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||||
unsafe {
|
unsafe {
|
||||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
@@ -521,17 +565,8 @@ impl VmpPMatOps for Module {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft_to_dft(
|
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||||
&self,
|
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||||
c: &mut VecZnxDft,
|
|
||||||
a: &VecZnxDft,
|
|
||||||
b: &VmpPMat,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
debug_assert!(
|
|
||||||
tmp_bytes.len()
|
|
||||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_alignement(tmp_bytes.as_ptr());
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
@@ -551,11 +586,29 @@ impl VmpPMatOps for Module {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||||
|
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
vmp::vmp_apply_dft_to_dft_add(
|
||||||
|
self.ptr,
|
||||||
|
c.ptr as *mut vec_znx_dft_t,
|
||||||
|
c.cols() as u64,
|
||||||
|
a.ptr as *const vec_znx_dft_t,
|
||||||
|
a.cols() as u64,
|
||||||
|
b.as_ptr() as *const vmp_pmat_t,
|
||||||
|
b.rows() as u64,
|
||||||
|
b.cols() as u64,
|
||||||
|
tmp_bytes.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||||
debug_assert!(
|
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
|
||||||
tmp_bytes.len()
|
|
||||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_alignement(tmp_bytes.as_ptr());
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
@@ -579,8 +632,7 @@ impl VmpPMatOps for Module {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{
|
use crate::{
|
||||||
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
|
Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
|
||||||
VecZnxOps, VmpPMat, VmpPMatOps,
|
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
@@ -598,13 +650,12 @@ mod tests {
|
|||||||
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||||
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||||
|
|
||||||
let mut tmp_bytes: Vec<u8> =
|
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||||
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
|
||||||
|
|
||||||
for row_i in 0..vpmat_rows {
|
for row_i in 0..vpmat_rows {
|
||||||
let mut source: Source = Source::new([0u8; 32]);
|
let mut source: Source = Source::new([0u8; 32]);
|
||||||
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
|
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
|
||||||
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
|
module.vec_znx_dft(&mut a_dft, &a);
|
||||||
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
||||||
|
|
||||||
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
|
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
|
||||||
@@ -617,7 +668,7 @@ mod tests {
|
|||||||
|
|
||||||
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
|
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
|
||||||
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
||||||
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
|
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
|
||||||
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
|
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Infos, BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
|
BACKEND, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
|
||||||
VmpPMat, alloc_aligned_u8,
|
|
||||||
};
|
};
|
||||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||||
use rlwe::{
|
use rlwe::{
|
||||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||||
elem::ElemCommon,
|
elem::ElemCommon,
|
||||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||||
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
|
gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes},
|
||||||
keys::SecretKey,
|
keys::SecretKey,
|
||||||
parameters::{Parameters, ParametersLiteral},
|
parameters::{Parameters, ParametersLiteral},
|
||||||
};
|
};
|
||||||
@@ -19,20 +18,16 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
res_dft_0: &'a mut VecZnxDft,
|
res_dft_0: &'a mut VecZnxDft,
|
||||||
res_dft_1: &'a mut VecZnxDft,
|
res_dft_1: &'a mut VecZnxDft,
|
||||||
a: &'a VecZnx,
|
a: &'a VecZnx,
|
||||||
a_cols: usize,
|
|
||||||
b: &'a Ciphertext<VmpPMat>,
|
b: &'a Ciphertext<VmpPMat>,
|
||||||
b_cols: usize,
|
b_cols: usize,
|
||||||
tmp_bytes: &'a mut [u8],
|
tmp_bytes: &'a mut [u8],
|
||||||
) -> Box<dyn FnMut() + 'a> {
|
) -> Box<dyn FnMut() + 'a> {
|
||||||
Box::new(move || {
|
Box::new(move || {
|
||||||
gadget_product_core(
|
gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes);
|
||||||
module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes,
|
|
||||||
);
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
|
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace");
|
||||||
c.benchmark_group("gadget_product_inplace");
|
|
||||||
|
|
||||||
for log_n in 10..11 {
|
for log_n in 10..11 {
|
||||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||||
@@ -50,7 +45,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
|
|
||||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||||
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||||
| gadget_product_tmp_bytes(
|
| gadget_product_core_tmp_bytes(
|
||||||
params.module(),
|
params.module(),
|
||||||
params.log_base2k(),
|
params.log_base2k(),
|
||||||
params.log_q(),
|
params.log_q(),
|
||||||
@@ -119,7 +114,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
.module()
|
.module()
|
||||||
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
|
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
|
||||||
|
|
||||||
let a_cols: usize = a.cols();
|
|
||||||
let b_cols: usize = gadget_ct.cols();
|
let b_cols: usize = gadget_ct.cols();
|
||||||
|
|
||||||
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
||||||
@@ -128,7 +122,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
&mut res_dft_0,
|
&mut res_dft_0,
|
||||||
&mut res_dft_1,
|
&mut res_dft_1,
|
||||||
&mut a,
|
&mut a,
|
||||||
a_cols,
|
|
||||||
&gadget_ct,
|
&gadget_ct,
|
||||||
b_cols,
|
b_cols,
|
||||||
&mut tmp_bytes,
|
&mut tmp_bytes,
|
||||||
|
|||||||
@@ -22,10 +22,8 @@ fn main() {

    let params: Parameters = Parameters::new(&params_lit);

    let mut tmp_bytes: Vec<u8> =
        alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()));

    let mut source: Source = Source::new([0; 32]);
    let mut sk: SecretKey = SecretKey::new(params.module());
|||||||
@@ -1,151 +0,0 @@
|
|||||||
use base2k::{
|
|
||||||
Encoding, Infos, Module, Sampling, SvpPPol, SvpPPolOps, VecZnx, VecZnxDftOps, VecZnxOps,
|
|
||||||
VmpPMat, VmpPMatOps, is_aligned,
|
|
||||||
};
|
|
||||||
use itertools::izip;
|
|
||||||
use rlwe::ciphertext::{Ciphertext, new_gadget_ciphertext};
|
|
||||||
use rlwe::elem::ElemCommon;
|
|
||||||
use rlwe::encryptor::encrypt_rlwe_sk;
|
|
||||||
use rlwe::keys::SecretKey;
|
|
||||||
use rlwe::plaintext::Plaintext;
|
|
||||||
use sampling::source::{Source, new_seed};
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
let n: usize = 32;
|
|
||||||
let module: Module = Module::new(n, base2k::BACKEND::FFT64);
|
|
||||||
let log_base2k: usize = 16;
|
|
||||||
let log_k: usize = 32;
|
|
||||||
let cols: usize = 4;
|
|
||||||
|
|
||||||
let mut a: VecZnx = module.new_vec_znx(cols);
|
|
||||||
let mut data: Vec<i64> = vec![0i64; n];
|
|
||||||
data[0] = 0;
|
|
||||||
data[1] = 0;
|
|
||||||
a.encode_vec_i64(log_base2k, log_k, &data, 16);
|
|
||||||
|
|
||||||
let mut a_dft: base2k::VecZnxDft = module.new_vec_znx_dft(cols);
|
|
||||||
|
|
||||||
module.vec_znx_dft(&mut a_dft, &a, cols);
|
|
||||||
|
|
||||||
(0..cols).for_each(|i| {
|
|
||||||
println!("{:?}", a_dft.at::<f64>(&module, i));
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct GadgetCiphertextProtocol {}
|
|
||||||
|
|
||||||
impl GadgetCiphertextProtocol {
|
|
||||||
pub fn new() -> GadgetCiphertextProtocol {
|
|
||||||
Self {}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allocate(
|
|
||||||
module: &Module,
|
|
||||||
log_base2k: usize,
|
|
||||||
rows: usize,
|
|
||||||
log_q: usize,
|
|
||||||
) -> GadgetCiphertextShare {
|
|
||||||
GadgetCiphertextShare::new(module, log_base2k, rows, log_q)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn gen_share(
|
|
||||||
module: &Module,
|
|
||||||
sk: &SecretKey,
|
|
||||||
pt: &Plaintext,
|
|
||||||
seed: &[u8; 32],
|
|
||||||
share: &mut GadgetCiphertextShare,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
share.seed.copy_from_slice(seed);
|
|
||||||
let mut source_xe: Source = Source::new(new_seed());
|
|
||||||
let mut source_xa: Source = Source::new(*seed);
|
|
||||||
let mut sk_ppol: SvpPPol = module.new_svp_ppol();
|
|
||||||
sk.prepare(module, &mut sk_ppol);
|
|
||||||
share.value.iter_mut().for_each(|ai| {
|
|
||||||
//let elem = Elem<VecZnx>{};
|
|
||||||
//encrypt_rlwe_sk_thread_safe(module, ai, Some(pt.elem()), &sk_ppol, &mut source_xa, &mut source_xe, 3.2, tmp_bytes);
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct GadgetCiphertextShare {
|
|
||||||
pub seed: [u8; 32],
|
|
||||||
pub log_q: usize,
|
|
||||||
pub log_base2k: usize,
|
|
||||||
pub value: Vec<VecZnx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GadgetCiphertextShare {
|
|
||||||
pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self {
|
|
||||||
let value: Vec<VecZnx> = Vec::new();
|
|
||||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
|
||||||
(0..rows).for_each(|_| {
|
|
||||||
let vec_znx: VecZnx = module.new_vec_znx(cols);
|
|
||||||
});
|
|
||||||
Self {
|
|
||||||
seed: [u8::default(); 32],
|
|
||||||
log_q: log_q,
|
|
||||||
log_base2k: log_base2k,
|
|
||||||
value: value,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rows(&self) -> usize {
|
|
||||||
self.value.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn cols(&self) -> usize {
|
|
||||||
self.value[0].cols()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn aggregate_inplace(&mut self, module: &Module, a: &GadgetCiphertextShare) {
|
|
||||||
izip!(self.value.iter_mut(), a.value.iter()).for_each(|(bi, ai)| {
|
|
||||||
module.vec_znx_add_inplace(bi, ai);
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, module: &Module, b: &mut Ciphertext<VmpPMat>, tmp_bytes: &mut [u8]) {
|
|
||||||
assert!(is_aligned(tmp_bytes.as_ptr()));
|
|
||||||
|
|
||||||
let rows: usize = b.rows();
|
|
||||||
let cols: usize = b.cols();
|
|
||||||
|
|
||||||
assert!(tmp_bytes.len() >= gadget_ciphertext_share_get_tmp_bytes(module, rows, cols));
|
|
||||||
|
|
||||||
assert_eq!(self.value.len(), rows);
|
|
||||||
assert_eq!(self.value[0].cols(), cols);
|
|
||||||
|
|
||||||
let (tmp_bytes_vmp_prepare_row, tmp_bytes_vec_znx) =
|
|
||||||
tmp_bytes.split_at_mut(module.vmp_prepare_tmp_bytes(rows, cols));
|
|
||||||
|
|
||||||
let mut c: VecZnx = VecZnx::from_bytes_borrow(module.n(), cols, tmp_bytes_vec_znx);
|
|
||||||
|
|
||||||
let mut source: Source = Source::new(self.seed);
|
|
||||||
|
|
||||||
(0..self.value.len()).for_each(|row_i| {
|
|
||||||
module.vmp_prepare_row(
|
|
||||||
b.at_mut(0),
|
|
||||||
self.value[row_i].raw(),
|
|
||||||
row_i,
|
|
||||||
tmp_bytes_vmp_prepare_row,
|
|
||||||
);
|
|
||||||
module.fill_uniform(self.log_base2k, &mut c, cols, &mut source);
|
|
||||||
module.vmp_prepare_row(b.at_mut(1), c.raw(), row_i, tmp_bytes_vmp_prepare_row)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_new(&self, module: &Module, tmp_bytes: &mut [u8]) -> Ciphertext<VmpPMat> {
|
|
||||||
let mut b: Ciphertext<VmpPMat> =
|
|
||||||
new_gadget_ciphertext(module, self.log_base2k, self.rows(), self.log_q);
|
|
||||||
self.get(module, &mut b, tmp_bytes);
|
|
||||||
b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn gadget_ciphertext_share_get_tmp_bytes(module: &Module, rows: usize, cols: usize) -> usize {
|
|
||||||
module.vmp_prepare_tmp_bytes(rows, cols) + module.bytes_of_vec_znx(cols)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CircularCiphertextProtocol {}
|
|
||||||
|
|
||||||
pub struct CircularGadgetCiphertextProtocol {}
|
|
||||||
349
rlwe/src/automorphism.rs
Normal file
349
rlwe/src/automorphism.rs
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
use crate::{
|
||||||
|
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||||
|
elem::ElemCommon,
|
||||||
|
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||||
|
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes},
|
||||||
|
keys::SecretKey,
|
||||||
|
parameters::Parameters,
|
||||||
|
};
|
||||||
|
use base2k::{
|
||||||
|
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||||
|
VmpPMatOps, assert_alignement,
|
||||||
|
};
|
||||||
|
use sampling::source::Source;
|
||||||
|
use std::{cmp::min, collections::HashMap};
|
||||||
|
|
||||||
|
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
|
||||||
|
pub struct AutomorphismKey {
|
||||||
|
pub value: Ciphertext<VmpPMat>,
|
||||||
|
pub p: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||||
|
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parameters {
|
||||||
|
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||||
|
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||||
|
automorphism_tmp_bytes(
|
||||||
|
self.module(),
|
||||||
|
self.log_base2k(),
|
||||||
|
res_logq,
|
||||||
|
in_logq,
|
||||||
|
gct_logq,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AutomorphismKey {
|
||||||
|
pub fn new(
|
||||||
|
module: &Module,
|
||||||
|
p: i64,
|
||||||
|
sk: &SecretKey,
|
||||||
|
log_base2k: usize,
|
||||||
|
rows: usize,
|
||||||
|
log_q: usize,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) -> Self {
|
||||||
|
Self::new_many_core(
|
||||||
|
module,
|
||||||
|
&vec![p],
|
||||||
|
sk,
|
||||||
|
log_base2k,
|
||||||
|
rows,
|
||||||
|
log_q,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
tmp_bytes,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_many(
|
||||||
|
module: &Module,
|
||||||
|
p: &Vec<i64>,
|
||||||
|
sk: &SecretKey,
|
||||||
|
log_base2k: usize,
|
||||||
|
rows: usize,
|
||||||
|
log_q: usize,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) -> HashMap<i64, AutomorphismKey> {
|
||||||
|
Self::new_many_core(
|
||||||
|
module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.zip(p.iter().cloned())
|
||||||
|
.map(|(key, pi)| (pi, key))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_many_core(
|
||||||
|
module: &Module,
|
||||||
|
p: &Vec<i64>,
|
||||||
|
sk: &SecretKey,
|
||||||
|
log_base2k: usize,
|
||||||
|
rows: usize,
|
||||||
|
log_q: usize,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) -> Vec<Self> {
|
||||||
|
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
|
||||||
|
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
|
||||||
|
|
||||||
|
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
|
||||||
|
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
|
||||||
|
|
||||||
|
let mut keys: Vec<AutomorphismKey> = Vec::new();
|
||||||
|
|
||||||
|
p.iter().for_each(|pi| {
|
||||||
|
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
|
||||||
|
|
||||||
|
let p_inv: i64 = module.galois_element_inv(*pi);
|
||||||
|
|
||||||
|
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
|
||||||
|
module.svp_prepare(&mut sk_out, &sk_auto);
|
||||||
|
encrypt_grlwe_sk(
|
||||||
|
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
keys.push(Self {
|
||||||
|
value: value,
|
||||||
|
p: *pi,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||||
|
key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism(
|
||||||
|
module: &Module,
|
||||||
|
c: &mut Ciphertext<VecZnx>,
|
||||||
|
a: &Ciphertext<VecZnx>,
|
||||||
|
b: &AutomorphismKey,
|
||||||
|
b_cols: usize,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
|
||||||
|
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||||
|
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
|
||||||
|
// c[1] = AUTO(b, p)
|
||||||
|
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||||
|
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||||
|
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_inplace(
|
||||||
|
module: &Module,
|
||||||
|
a: &mut Ciphertext<VecZnx>,
|
||||||
|
b: &AutomorphismKey,
|
||||||
|
b_cols: usize,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
|
||||||
|
// a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||||
|
module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
|
||||||
|
// a[1] = AUTO(b, p)
|
||||||
|
module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_big(
|
||||||
|
module: &Module,
|
||||||
|
c: &mut Ciphertext<VecZnxBig>,
|
||||||
|
a: &Ciphertext<VecZnx>,
|
||||||
|
b: &AutomorphismKey,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
let cols = std::cmp::min(c.cols(), a.cols());
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
|
||||||
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
|
||||||
|
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||||
|
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||||
|
|
||||||
|
// a1_dft = DFT(a[1])
|
||||||
|
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||||
|
|
||||||
|
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||||
|
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||||
|
module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
|
||||||
|
|
||||||
|
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||||
|
module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
|
||||||
|
|
||||||
|
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||||
|
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
|
||||||
|
|
||||||
|
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||||
|
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||||
|
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
|
||||||
|
|
||||||
|
// c[1] = AUTO(b, p)
|
||||||
|
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::{AutomorphismKey, automorphism};
|
||||||
|
use crate::{
|
||||||
|
ciphertext::Ciphertext,
|
||||||
|
decryptor::decrypt_rlwe,
|
||||||
|
elem::ElemCommon,
|
||||||
|
encryptor::encrypt_rlwe_sk,
|
||||||
|
keys::SecretKey,
|
||||||
|
parameters::{Parameters, ParametersLiteral},
|
||||||
|
plaintext::Plaintext,
|
||||||
|
};
|
||||||
|
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
|
||||||
|
use sampling::source::{Source, new_seed};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_automorphism() {
|
||||||
|
let log_base2k: usize = 10;
|
||||||
|
let log_q: usize = 50;
|
||||||
|
let log_p: usize = 15;
|
||||||
|
|
||||||
|
// Basic parameters with enough limbs to test edge cases
|
||||||
|
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||||
|
backend: BACKEND::FFT64,
|
||||||
|
log_n: 12,
|
||||||
|
log_q: log_q,
|
||||||
|
log_p: log_p,
|
||||||
|
log_base2k: log_base2k,
|
||||||
|
log_scale: 20,
|
||||||
|
xe: 3.2,
|
||||||
|
xs: 1 << 11,
|
||||||
|
};
|
||||||
|
|
||||||
|
let params: Parameters = Parameters::new(¶ms_lit);
|
||||||
|
|
||||||
|
let module: &Module = params.module();
|
||||||
|
let log_q: usize = params.log_q();
|
||||||
|
let log_qp: usize = params.log_qp();
|
||||||
|
let gct_rows: usize = params.cols_q();
|
||||||
|
let gct_cols: usize = params.cols_qp();
|
||||||
|
|
||||||
|
// scratch space
|
||||||
|
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||||
|
params.decrypt_rlwe_tmp_byte(log_q)
|
||||||
|
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
|
||||||
|
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
|
||||||
|
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Samplers for public and private randomness
|
||||||
|
let mut source_xe: Source = Source::new(new_seed());
|
||||||
|
let mut source_xa: Source = Source::new(new_seed());
|
||||||
|
let mut source_xs: Source = Source::new(new_seed());
|
||||||
|
|
||||||
|
let mut sk: SecretKey = SecretKey::new(module);
|
||||||
|
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||||
|
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
|
||||||
|
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||||
|
|
||||||
|
let p: i64 = -5;
|
||||||
|
|
||||||
|
let auto_key: AutomorphismKey = AutomorphismKey::new(
|
||||||
|
module,
|
||||||
|
p,
|
||||||
|
&sk,
|
||||||
|
log_base2k,
|
||||||
|
gct_rows,
|
||||||
|
log_qp,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
params.xe(),
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut data: Vec<i64> = vec![0i64; params.n()];
|
||||||
|
|
||||||
|
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||||
|
|
||||||
|
let log_k: usize = 2 * log_base2k;
|
||||||
|
|
||||||
|
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||||
|
let mut pt: Plaintext = params.new_plaintext(log_q);
|
||||||
|
let mut pt_auto: Plaintext = params.new_plaintext(log_q);
|
||||||
|
|
||||||
|
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
|
||||||
|
module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
|
||||||
|
|
||||||
|
encrypt_rlwe_sk(
|
||||||
|
module,
|
||||||
|
&mut ct.elem_mut(),
|
||||||
|
Some(pt.at(0)),
|
||||||
|
&sk_svp_ppol,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
params.xe(),
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||||
|
|
||||||
|
// ct <- AUTO(ct)
|
||||||
|
automorphism(
|
||||||
|
module,
|
||||||
|
&mut ct_auto,
|
||||||
|
&ct,
|
||||||
|
&auto_key,
|
||||||
|
gct_cols,
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
// pt = dec(auto(ct)) - auto(pt)
|
||||||
|
decrypt_rlwe(
|
||||||
|
module,
|
||||||
|
pt.elem_mut(),
|
||||||
|
ct_auto.elem(),
|
||||||
|
&sk_svp_ppol,
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
|
||||||
|
|
||||||
|
// pt.at(0).print(pt.cols(), 16);
|
||||||
|
|
||||||
|
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
|
||||||
|
|
||||||
|
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
|
||||||
|
let var_a_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
|
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
|
||||||
|
|
||||||
|
println!("noise_pred: {}", noise_pred);
|
||||||
|
println!("noise_have: {}", noise_have);
|
||||||
|
|
||||||
|
assert!(noise_have <= noise_pred + 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -74,24 +74,14 @@ pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) ->
    Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
}

pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
    let cols: usize = (log_q + log_base2k - 1) / log_base2k;
    let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
    elem.log_q = log_q;
    Ciphertext(elem)
}

pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
    let cols: usize = (log_q + log_base2k - 1) / log_base2k;
    let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
    elem.log_q = log_q;
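The `cols` expression above is a ceiling division: the number of base-2^k limbs needed to cover `log_q` bits. A quick numeric check with purely illustrative values:

fn main() {
    let (log_q, log_base2k) = (54usize, 16usize);
    let cols = (log_q + log_base2k - 1) / log_base2k; // ceil(54 / 16)
    assert_eq!(cols, 4);
}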
|||||||
@@ -9,6 +9,7 @@ use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZn
|
|||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
|
||||||
pub struct Decryptor {
|
pub struct Decryptor {
|
||||||
|
#[warn(dead_code)]
|
||||||
sk: SvpPPol,
|
sk: SvpPPol,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,24 +33,12 @@ impl Parameters {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrypt_rlwe(
|
pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
|
||||||
&self,
|
|
||||||
res: &mut Plaintext,
|
|
||||||
ct: &Ciphertext<VecZnx>,
|
|
||||||
sk: &SvpPPol,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
|
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrypt_rlwe(
|
pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
|
||||||
module: &Module,
|
|
||||||
res: &mut Elem<VecZnx>,
|
|
||||||
a: &Elem<VecZnx>,
|
|
||||||
sk: &SvpPPol,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
let cols: usize = a.cols();
|
let cols: usize = a.cols();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
@@ -59,17 +48,15 @@ pub fn decrypt_rlwe(
|
|||||||
decrypt_rlwe_tmp_byte(module, cols)
|
decrypt_rlwe_tmp_byte(module, cols)
|
||||||
);
|
);
|
||||||
|
|
||||||
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) =
|
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
|
||||||
|
|
||||||
let mut res_dft: VecZnxDft =
|
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
|
||||||
VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft);
|
|
||||||
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
|
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
|
||||||
|
|
||||||
// res_dft <- DFT(ct[1]) * DFT(sk)
|
// res_dft <- DFT(ct[1]) * DFT(sk)
|
||||||
module.svp_apply_dft(&mut res_dft, sk, a.at(1), cols);
|
module.svp_apply_dft(&mut res_dft, sk, a.at(1));
|
||||||
// res_big <- ct[1] x sk
|
// res_big <- ct[1] x sk
|
||||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
|
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||||
// res_big <- ct[1] x sk + ct[0]
|
// res_big <- ct[1] x sk + ct[0]
|
||||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||||
// res <- normalize(ct[1] x sk + ct[0])
|
// res <- normalize(ct[1] x sk + ct[0])
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
|
use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
|
||||||
|
|
||||||
use crate::parameters::Parameters;
|
|
||||||
|
|
||||||
pub struct Elem<T> {
|
pub struct Elem<T> {
|
||||||
pub value: Vec<T>,
|
pub value: Vec<T>,
|
||||||
pub log_base2k: usize,
|
pub log_base2k: usize,
|
||||||
@@ -10,20 +8,8 @@ pub struct Elem<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait ElemVecZnx {
|
pub trait ElemVecZnx {
|
||||||
fn from_bytes(
|
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
|
||||||
module: &Module,
|
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
|
||||||
log_base2k: usize,
|
|
||||||
log_q: usize,
|
|
||||||
size: usize,
|
|
||||||
bytes: &mut [u8],
|
|
||||||
) -> Elem<VecZnx>;
|
|
||||||
fn from_bytes_borrow(
|
|
||||||
module: &Module,
|
|
||||||
log_base2k: usize,
|
|
||||||
log_q: usize,
|
|
||||||
size: usize,
|
|
||||||
bytes: &mut [u8],
|
|
||||||
) -> Elem<VecZnx>;
|
|
||||||
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
|
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
|
||||||
fn zero(&mut self);
|
fn zero(&mut self);
|
||||||
}
|
}
|
||||||
@@ -34,13 +20,7 @@ impl ElemVecZnx for Elem<VecZnx> {
|
|||||||
module.n() * cols * size * 8
|
module.n() * cols * size * 8
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_bytes(
|
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
|
||||||
module: &Module,
|
|
||||||
log_base2k: usize,
|
|
||||||
log_q: usize,
|
|
||||||
size: usize,
|
|
||||||
bytes: &mut [u8],
|
|
||||||
) -> Elem<VecZnx> {
|
|
||||||
assert!(size > 0);
|
assert!(size > 0);
|
||||||
let n: usize = module.n();
|
let n: usize = module.n();
|
||||||
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
||||||
@@ -60,13 +40,7 @@ impl ElemVecZnx for Elem<VecZnx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_bytes_borrow(
|
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
|
||||||
module: &Module,
|
|
||||||
log_base2k: usize,
|
|
||||||
log_q: usize,
|
|
||||||
size: usize,
|
|
||||||
bytes: &mut [u8],
|
|
||||||
) -> Elem<VecZnx> {
|
|
||||||
assert!(size > 0);
|
assert!(size > 0);
|
||||||
let n: usize = module.n();
|
let n: usize = module.n();
|
||||||
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
||||||
|
|||||||
@@ -5,12 +5,38 @@ use crate::parameters::Parameters;
use crate::plaintext::Plaintext;
use base2k::sampling::Sampling;
use base2k::{
    Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
    VmpPMatOps,
};

use sampling::source::{Source, new_seed};

impl Parameters {
    pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
        encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
    }

    pub fn encrypt_rlwe_sk(
        &self,
        ct: &mut Ciphertext<VecZnx>,
        pt: Option<&Plaintext>,
        sk: &SvpPPol,
        source_xa: &mut Source,
        source_xe: &mut Source,
        tmp_bytes: &mut [u8],
    ) {
        encrypt_rlwe_sk(
            self.module(),
            &mut ct.0,
            pt.map(|pt| pt.at(0)),
            sk,
            source_xa,
            source_xe,
            self.xe(),
            tmp_bytes,
        )
    }
}

pub struct EncryptorSk {
    sk: SvpPPol,
    source_xa: Source,
@@ -49,12 +75,7 @@ impl EncryptorSk {
        self.source_xe = Source::new(seed)
    }

    pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext<VecZnx>, pt: Option<&Plaintext>) {
        assert!(
            self.initialized == true,
            "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
@@ -86,42 +107,26 @@ impl EncryptorSk {
        }
    }
}

pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
    module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
}

pub fn encrypt_rlwe_sk(
    module: &Module,
    ct: &mut Elem<VecZnx>,
    pt: Option<&VecZnx>,
    sk: &SvpPPol,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
    tmp_bytes: &mut [u8],
) {
    encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes)
}

fn encrypt_rlwe_sk_core<const PT_POS: u8>(
    module: &Module,
    ct: &mut Elem<VecZnx>,
    pt: Option<&VecZnx>,
    sk: &SvpPPol,
    source_xa: &mut Source,
    source_xe: &mut Source,
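A side note on the sizing expression used by `encrypt_rlwe_sk_tmp_bytes` above: `(log_q + log_base2k - 1) / log_base2k` is integer ceiling division, counting how many base-2^k limbs ("cols") are needed to hold a log_q-bit modulus. A minimal standalone sketch, with made-up values that are not taken from any parameter set in this diff:

    // Number of base-2^k limbs needed for a log_q-bit modulus (ceiling division).
    fn cols(log_q: usize, log_base2k: usize) -> usize {
        (log_q + log_base2k - 1) / log_base2k
    }

    fn main() {
        assert_eq!(cols(54, 15), 4); // 54 bits split into 15-bit limbs -> 4 limbs
        assert_eq!(cols(45, 15), 3); // exact multiple -> no extra limb
    }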
@@ -146,36 +151,49 @@ pub fn encrypt_rlwe_sk(
    // c1 <- Z_{2^prec}[X]/(X^{N}+1)
    module.fill_uniform(log_base2k, c1, cols, source_xa);

    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));

    // Scratch space for DFT values
    let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);

    // Applies buf_dft <- DFT(s) * DFT(c1)
    module.svp_apply_dft(&mut buf_dft, sk, c1);

    // Alias scratch space
    let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();

    // buf_big = s x c1
    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);

    match PT_POS {
        // c0 <- -s x c1 + m
        0 => {
            let c0: &mut VecZnx = ct.at_mut(0);
            if let Some(pt) = pt {
                module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt);
                module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
            } else {
                module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
                module.vec_znx_negate_inplace(c0);
            }
        }
        // c1 <- c1 + m
        1 => {
            if let Some(pt) = pt {
                module.vec_znx_add_inplace(c1, pt);
                c1.normalize(log_base2k, tmp_bytes_normalize);
            }
            let c0: &mut VecZnx = ct.at_mut(0);
            module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
            module.vec_znx_negate_inplace(c0);
        }
        _ => panic!("PT_POS must be 0 or 1"),
    }

    // c0 <- -s x c1 + m + e
    module.add_normal(
        log_base2k,
        ct.at_mut(0),
        log_q,
        source_xe,
        sigma,
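The `PT_POS` const generic selects where the optional plaintext is folded in. Summarizing the two match arms above (a reading of the code, not text from the commit), and ignoring normalization and the error added afterwards:

    PT_POS = 0:  (c0, c1) = (-s*c1 + m + e, c1)        // plaintext in the body
    PT_POS = 1:  (c0, c1) = (-s*c1 + e,     c1 + m)    // plaintext on the mask

The second form is the one `encrypt_grlwe_sk_core::<1>` uses further down; the interpretation that it produces the rows whose decryption contributes m * s, i.e. the second block of an RGSW sample, is stated here as an assumption based on the surrounding code.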
@@ -189,12 +207,7 @@ impl Parameters {
    }
}

pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
    let cols = (log_q + log_base2k - 1) / log_base2k;
    Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
        + Plaintext::bytes_of(module, log_base2k, log_q)
@@ -212,10 +225,93 @@ pub fn encrypt_grlwe_sk(
|
|||||||
sigma: f64,
|
sigma: f64,
|
||||||
tmp_bytes: &mut [u8],
|
tmp_bytes: &mut [u8],
|
||||||
) {
|
) {
|
||||||
let rows: usize = ct.rows();
|
|
||||||
let log_q: usize = ct.log_q();
|
let log_q: usize = ct.log_q();
|
||||||
//let cols: usize = (log_q + ct.log_base2k() - 1) / ct.log_base2k();
|
|
||||||
let log_base2k: usize = ct.log_base2k();
|
let log_base2k: usize = ct.log_base2k();
|
||||||
|
let (left, right) = ct.0.value.split_at_mut(1);
|
||||||
|
encrypt_grlwe_sk_core::<0>(
|
||||||
|
module,
|
||||||
|
log_base2k,
|
||||||
|
[&mut left[0], &mut right[0]],
|
||||||
|
log_q,
|
||||||
|
m,
|
||||||
|
sk,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
tmp_bytes,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parameters {
|
||||||
|
pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||||
|
encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||||
|
let cols = (log_q + log_base2k - 1) / log_base2k;
|
||||||
|
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
|
||||||
|
+ Plaintext::bytes_of(module, log_base2k, log_q)
|
||||||
|
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||||
|
+ module.vmp_prepare_tmp_bytes(rows, cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_rgsw_sk(
|
||||||
|
module: &Module,
|
||||||
|
ct: &mut Ciphertext<VmpPMat>,
|
||||||
|
m: &Scalar,
|
||||||
|
sk: &SvpPPol,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
let log_q: usize = ct.log_q();
|
||||||
|
let log_base2k: usize = ct.log_base2k();
|
||||||
|
|
||||||
|
let (left, right) = ct.0.value.split_at_mut(2);
|
||||||
|
let (ll, lr) = left.split_at_mut(1);
|
||||||
|
let (rl, rr) = right.split_at_mut(1);
|
||||||
|
|
||||||
|
encrypt_grlwe_sk_core::<0>(
|
||||||
|
module,
|
||||||
|
log_base2k,
|
||||||
|
[&mut ll[0], &mut lr[0]],
|
||||||
|
log_q,
|
||||||
|
m,
|
||||||
|
sk,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
tmp_bytes,
|
||||||
|
);
|
||||||
|
encrypt_grlwe_sk_core::<1>(
|
||||||
|
module,
|
||||||
|
log_base2k,
|
||||||
|
[&mut rl[0], &mut rr[0]],
|
||||||
|
log_q,
|
||||||
|
m,
|
||||||
|
sk,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
tmp_bytes,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_grlwe_sk_core<const PT_POS: u8>(
|
||||||
|
module: &Module,
|
||||||
|
log_base2k: usize,
|
||||||
|
mut ct: [&mut VmpPMat; 2],
|
||||||
|
log_q: usize,
|
||||||
|
m: &Scalar,
|
||||||
|
sk: &SvpPPol,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
let rows: usize = ct[0].rows();
|
||||||
|
|
||||||
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
|
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
|
||||||
|
|
||||||
@@ -234,20 +330,18 @@ pub fn encrypt_grlwe_sk(
|
|||||||
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
|
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
|
||||||
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
|
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
|
||||||
|
|
||||||
let mut tmp_elem: Elem<VecZnx> =
|
let mut tmp_elem: Elem<VecZnx> = Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem);
|
||||||
Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem);
|
let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
|
||||||
let mut tmp_pt: Plaintext =
|
|
||||||
Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
|
|
||||||
|
|
||||||
(0..rows).for_each(|row_i| {
|
(0..rows).for_each(|row_i| {
|
||||||
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
|
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
|
||||||
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
|
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
|
||||||
|
|
||||||
// Encrypts RLWE(m * 2^{-log_base2k*i})
|
// Encrypts RLWE(m * 2^{-log_base2k*i})
|
||||||
encrypt_rlwe_sk(
|
encrypt_rlwe_sk_core::<PT_POS>(
|
||||||
module,
|
module,
|
||||||
&mut tmp_elem,
|
&mut tmp_elem,
|
||||||
Some(&tmp_pt.0),
|
Some(tmp_pt.at(0)),
|
||||||
sk,
|
sk,
|
||||||
source_xa,
|
source_xa,
|
||||||
source_xe,
|
source_xe,
|
||||||
@@ -255,31 +349,21 @@ pub fn encrypt_grlwe_sk(
|
|||||||
tmp_bytes_enc_sk,
|
tmp_bytes_enc_sk,
|
||||||
);
|
);
|
||||||
|
|
||||||
//tmp_pt.at(0).print(tmp_pt.cols(), 16);
|
|
||||||
//println!();
|
|
||||||
|
|
||||||
// Zeroes the ith-row of tmp_pt
|
// Zeroes the ith-row of tmp_pt
|
||||||
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
|
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
|
||||||
|
|
||||||
//println!("row:{}/{}", row_i, rows);
|
|
||||||
//tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n());
|
|
||||||
//tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n());
|
|
||||||
//println!();
|
|
||||||
//println!(">>>");
|
|
||||||
|
|
||||||
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
|
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
|
||||||
module.vmp_prepare_row(
|
module.vmp_prepare_row(
|
||||||
&mut ct.at_mut(0),
|
ct[0],
|
||||||
tmp_elem.at(0).raw(),
|
tmp_elem.at(0).raw(),
|
||||||
row_i,
|
row_i,
|
||||||
tmp_bytes_vmp_prepare_row,
|
tmp_bytes_vmp_prepare_row,
|
||||||
);
|
);
|
||||||
module.vmp_prepare_row(
|
module.vmp_prepare_row(
|
||||||
&mut ct.at_mut(1),
|
&mut ct[1],
|
||||||
tmp_elem.at(1).raw(),
|
tmp_elem.at(1).raw(),
|
||||||
row_i,
|
row_i,
|
||||||
tmp_bytes_vmp_prepare_row,
|
tmp_bytes_vmp_prepare_row,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
//println!("DONE");
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min;

pub fn gadget_product_core_tmp_bytes(
    module: &Module,
    log_base2k: usize,
    res_log_q: usize,
@@ -17,14 +17,8 @@ pub fn gadget_product_tmp_bytes(
}

impl Parameters {
    pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize {
        gadget_product_core_tmp_bytes(
            self.module(),
            self.log_base2k(),
            res_log_q,
@@ -35,54 +29,93 @@ impl Parameters {
    }
}

pub fn gadget_product_core(
    module: &Module,
    res_dft_0: &mut VecZnxDft,
    res_dft_1: &mut VecZnxDft,
    a: &VecZnx,
    b: &Ciphertext<VmpPMat>,
    b_cols: usize,
    tmp_bytes: &mut [u8],
) {
    assert!(b_cols <= b.cols());
    module.vec_znx_dft(res_dft_1, a);
    module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
    module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
}
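For reference, the doc comment carried by the previous version of this routine (visible in the removed side of this hunk) summarized the computation as:

    res = sum_{i < min(a_cols, b_rows)} decomp(a, i) * (-B[i]*s + m * 2^{-k*i} + E[i], B[i])
        = (-c*s + m * a + e, c)

that is, `a` is decomposed in base 2^k, each digit multiplies one row of the gadget ciphertext `b`, and the sum is a fresh RLWE sample encrypting m * a.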
|
|
||||||
/*
|
pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||||
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||||
module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
+ 2 * module.bytes_of_vec_znx_dft(min(c_cols, a_cols));
|
||||||
module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
}
|
||||||
|
|
||||||
// res_big <- res[0] + res_big[a*G0]
|
/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||||
module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0));
|
///
|
||||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry);
|
/// # Arguments
|
||||||
|
///
|
||||||
if OVERWRITE {
|
/// * `module`: backend support for operations mod (X^N + 1).
|
||||||
// res[1] = normalize(res_big[a*G1])
|
/// * `c`: a [Ciphertext<VecZnxBig>] with cols_c cols.
|
||||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||||
} else {
|
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||||
// res[1] = normalize(res_big[a*G1] + res[1])
|
pub fn gadget_product_big(
|
||||||
module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1));
|
module: &Module,
|
||||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
c: &mut Ciphertext<VecZnxBig>,
|
||||||
|
a: &Ciphertext<VecZnx>,
|
||||||
|
b: &Ciphertext<VmpPMat>,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
let cols: usize = min(c.cols(), a.cols());
|
||||||
|
|
||||||
|
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
|
||||||
|
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||||
|
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||||
|
|
||||||
|
// a1_dft = DFT(a[1])
|
||||||
|
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||||
|
|
||||||
|
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||||
|
(0..2).for_each(|i| {
|
||||||
|
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||||
|
module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `module`: backend support for operations mod (X^N + 1).
|
||||||
|
/// * `c`: a [Ciphertext<VecZnx>] with cols_c cols.
|
||||||
|
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||||
|
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||||
|
pub fn gadget_product(
|
||||||
|
module: &Module,
|
||||||
|
c: &mut Ciphertext<VecZnx>,
|
||||||
|
a: &Ciphertext<VecZnx>,
|
||||||
|
b: &Ciphertext<VmpPMat>,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
let cols: usize = min(c.cols(), a.cols());
|
||||||
|
|
||||||
|
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||||
|
|
||||||
|
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||||
|
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||||
|
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||||
|
|
||||||
|
// a1_dft = DFT(a[1])
|
||||||
|
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||||
|
|
||||||
|
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||||
|
(0..2).for_each(|i| {
|
||||||
|
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||||
|
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||||
|
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
|
||||||
|
})
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
@@ -97,8 +130,8 @@ mod test {
|
|||||||
plaintext::Plaintext,
|
plaintext::Plaintext,
|
||||||
};
|
};
|
||||||
use base2k::{
|
use base2k::{
|
||||||
Infos, BACKEND, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||||
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
|
alloc_aligned_u8,
|
||||||
};
|
};
|
||||||
use sampling::source::{Source, new_seed};
|
use sampling::source::{Source, new_seed};
|
||||||
|
|
||||||
@@ -125,7 +158,6 @@ mod test {
|
|||||||
// scratch space
|
// scratch space
|
||||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||||
params.decrypt_rlwe_tmp_byte(params.log_qp())
|
params.decrypt_rlwe_tmp_byte(params.log_qp())
|
||||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())
|
|
||||||
| params.gadget_product_tmp_bytes(
|
| params.gadget_product_tmp_bytes(
|
||||||
params.log_qp(),
|
params.log_qp(),
|
||||||
params.log_qp(),
|
params.log_qp(),
|
||||||
@@ -172,10 +204,6 @@ mod test {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Intermediate buffers
|
// Intermediate buffers
|
||||||
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
|
||||||
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
|
||||||
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
|
|
||||||
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
|
|
||||||
|
|
||||||
// Input polynopmial, uniformly distributed
|
// Input polynopmial, uniformly distributed
|
||||||
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
|
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
|
||||||
@@ -184,8 +212,7 @@ mod test {
|
|||||||
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
|
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
|
||||||
|
|
||||||
// res = g^-1(a) * gct
|
// res = g^-1(a) * gct
|
||||||
let mut elem_res: Elem<VecZnx> =
|
let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
|
||||||
Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
|
|
||||||
|
|
||||||
// Ideal output = a * s
|
// Ideal output = a * s
|
||||||
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
|
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
|
||||||
@@ -193,31 +220,31 @@ mod test {
|
|||||||
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
|
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
|
||||||
|
|
||||||
// a * sk0
|
// a * sk0
|
||||||
|
params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
|
||||||
|
params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
|
||||||
params
|
params
|
||||||
.module()
|
.module()
|
||||||
.svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols());
|
.vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes);
|
||||||
params
|
|
||||||
.module()
|
|
||||||
.vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols());
|
|
||||||
params.module().vec_znx_big_normalize(
|
|
||||||
params.log_base2k(),
|
|
||||||
&mut a_times_s,
|
|
||||||
&a_big,
|
|
||||||
&mut tmp_bytes,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Plaintext for decrypted output of gadget product
|
// Plaintext for decrypted output of gadget product
|
||||||
let mut pt: Plaintext =
|
let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
|
||||||
Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
|
|
||||||
|
|
||||||
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
|
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
|
||||||
|
|
||||||
(1..a.cols() + 1).for_each(|a_cols| {
|
(1..a.cols() + 1).for_each(|a_cols| {
|
||||||
|
let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols);
|
||||||
|
a_trunc.copy_from(&a);
|
||||||
|
|
||||||
(1..gadget_ct.cols() + 1).for_each(|b_cols| {
|
(1..gadget_ct.cols() + 1).for_each(|b_cols| {
|
||||||
|
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
|
||||||
|
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
|
||||||
|
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
|
||||||
|
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
|
||||||
|
|
||||||
pt.elem_mut().zero();
|
pt.elem_mut().zero();
|
||||||
elem_res.zero();
|
elem_res.zero();
|
||||||
|
|
||||||
//let b_cols: usize = min(a_cols+1, gadget_ct.cols());
|
// let b_cols: usize = min(a_cols+1, gadget_ct.cols());
|
||||||
|
|
||||||
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
|
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
|
||||||
|
|
||||||
@@ -227,8 +254,7 @@ mod test {
|
|||||||
params.module(),
|
params.module(),
|
||||||
&mut res_dft_0,
|
&mut res_dft_0,
|
||||||
&mut res_dft_1,
|
&mut res_dft_1,
|
||||||
&a,
|
&a_trunc,
|
||||||
a_cols,
|
|
||||||
&gadget_ct,
|
&gadget_ct,
|
||||||
b_cols,
|
b_cols,
|
||||||
&mut tmp_bytes,
|
&mut tmp_bytes,
|
||||||
@@ -237,27 +263,21 @@ mod test {
|
|||||||
// res_big_0 = IDFT(res_dft_0)
|
// res_big_0 = IDFT(res_dft_0)
|
||||||
params
|
params
|
||||||
.module()
|
.module()
|
||||||
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
|
||||||
// res_big_1 = IDFT(res_dft_1);
|
// res_big_1 = IDFT(res_dft_1);
|
||||||
params
|
params
|
||||||
.module()
|
.module()
|
||||||
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
|
||||||
|
|
||||||
// res_big_0 = normalize(res_big_0)
|
// res_big_0 = normalize(res_big_0)
|
||||||
params.module().vec_znx_big_normalize(
|
params
|
||||||
log_base2k,
|
.module()
|
||||||
elem_res.at_mut(0),
|
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes);
|
||||||
&res_big_0,
|
|
||||||
&mut tmp_bytes,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res_big_1 = normalize(res_big_1)
|
// res_big_1 = normalize(res_big_1)
|
||||||
params.module().vec_znx_big_normalize(
|
params
|
||||||
log_base2k,
|
.module()
|
||||||
elem_res.at_mut(1),
|
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes);
|
||||||
&res_big_1,
|
|
||||||
&mut tmp_bytes,
|
|
||||||
);
|
|
||||||
|
|
||||||
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
|
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
|
||||||
decrypt_rlwe(
|
decrypt_rlwe(
|
||||||
@@ -271,10 +291,10 @@ mod test {
|
|||||||
// a * sk0 + e - a*sk0 = e
|
// a * sk0 + e - a*sk0 = e
|
||||||
params
|
params
|
||||||
.module()
|
.module()
|
||||||
.vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s);
|
.vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s);
|
||||||
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
|
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
|
||||||
|
|
||||||
//pt.at(0).print(pt.elem().cols(), 16);
|
// pt.at(0).print(pt.elem().cols(), 16);
|
||||||
|
|
||||||
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
|
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
|
||||||
|
|
||||||
@@ -288,28 +308,23 @@ mod test {
|
|||||||
|
|
||||||
let a_logq: usize = a_cols * log_base2k;
|
let a_logq: usize = a_cols * log_base2k;
|
||||||
let b_logq: usize = b_cols * log_base2k;
|
let b_logq: usize = b_cols * log_base2k;
|
||||||
let var_msg: f64 = params.xs() as f64;
|
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
|
||||||
|
|
||||||
let noise_pred: f64 =
|
println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq);
|
||||||
params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
|
|
||||||
|
|
||||||
assert!(noise_have <= noise_pred + 1.0);
|
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
|
||||||
|
|
||||||
println!("noise_pred: {}", noise_have);
|
println!("noise_pred: {}", noise_pred);
|
||||||
println!("noise_have: {}", noise_pred);
|
println!("noise_have: {}", noise_have);
|
||||||
|
|
||||||
|
// assert!(noise_have <= noise_pred + 1.0);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Parameters {
|
impl Parameters {
|
||||||
pub fn noise_grlwe_product(
|
pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 {
|
||||||
&self,
|
|
||||||
var_msg: f64,
|
|
||||||
var_a_err: f64,
|
|
||||||
a_logq: usize,
|
|
||||||
b_logq: usize,
|
|
||||||
) -> f64 {
|
|
||||||
let n: f64 = self.n() as f64;
|
let n: f64 = self.n() as f64;
|
||||||
let var_xs: f64 = self.xs() as f64;
|
let var_xs: f64 = self.xs() as f64;
|
||||||
|
|
||||||
@@ -360,9 +375,8 @@ pub fn noise_grlwe_product(

    // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
    // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
    let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
    noise += var_msg * var_a_err * a_scale * a_scale * n;
    noise = noise.sqrt();
    noise /= b_scale;
    noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
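Written out as a single formula, the estimate computed above is (a transcription of the code under its own variable names, assuming a_cols, a_scale and b_scale are defined as in noise_rgsw_product later in this diff, i.e. a_cols = ceil(a_logq / log_base2k), a_scale = 2^(b_logq - a_logq), b_scale = 2^b_logq):

    noise_pred = log2( sqrt( a_cols * n * (B^2 / 12) * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
                           + var_msg * var_a_err * a_scale^2 * n ) / b_scale )

with B = 2^log_base2k, and the result clamped to at most -1 since the normalized noise lives in [-1/2, 1/2].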
@@ -7,11 +7,7 @@ use sampling::source::Source;
pub struct KeyGenerator {}

impl KeyGenerator {
    pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey {
        let mut sk: SecretKey = SecretKey::new(params.module());
        sk.fill_ternary_hw(params.xs(), source);
        sk
@@ -26,8 +22,7 @@ impl KeyGenerator {
    ) -> PublicKey {
        let mut xa_source: Source = source.branch();
        let mut xe_source: Source = source.branch();
        let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
        pk.gen_thread_safe(
            params.module(),
            sk_ppol,
@@ -40,12 +35,7 @@ impl KeyGenerator {
    }
}

pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
    encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
79 rlwe/src/key_switching.rs Normal file
@@ -0,0 +1,79 @@
use crate::ciphertext::Ciphertext;
use crate::elem::ElemCommon;
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
use std::cmp::min;

pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
    let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
    let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
    let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
    return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
        + module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
        + module.bytes_of_vec_znx_dft(gct_cols);
}

pub fn key_switch_rlwe(
    module: &Module,
    c: &mut Ciphertext<VecZnx>,
    a: &Ciphertext<VecZnx>,
    b: &Ciphertext<VmpPMat>,
    b_cols: usize,
    tmp_bytes: &mut [u8],
) {
    key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes);
}

pub fn key_switch_rlwe_inplace(
    module: &Module,
    a: &mut Ciphertext<VecZnx>,
    b: &Ciphertext<VmpPMat>,
    b_cols: usize,
    tmp_bytes: &mut [u8],
) {
    key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes);
}

fn key_switch_rlwe_core(
    module: &Module,
    c: *mut Ciphertext<VecZnx>,
    a: *const Ciphertext<VecZnx>,
    b: &Ciphertext<VmpPMat>,
    b_cols: usize,
    tmp_bytes: &mut [u8],
) {
    // SAFETY WARNING: must ensure `c` and `a` are valid for read/write
    let c: &mut Ciphertext<VecZnx> = unsafe { &mut *c };
    let a: &Ciphertext<VecZnx> = unsafe { &*a };

    let cols: usize = min(min(c.cols(), a.cols()), b.rows());

    #[cfg(debug_assertions)]
    {
        assert!(b_cols <= b.cols());
        assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.cols(), a.cols(), b.rows(), b.cols()));
        assert_alignement(tmp_bytes.as_ptr());
    }

    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));

    let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
    let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
    let mut res_big = res_dft.as_vec_znx_big();

    module.vec_znx_dft(&mut a1_dft, a.at(1));
    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes);
    module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);

    module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);

    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes);
    module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);

    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
}

pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}

pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}
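With `b` holding gadget encryptions of the source secret s_in under the destination secret s_out (the intended use, stated here as an assumption since this file does not construct `b` itself), the two vmp products in key_switch_rlwe_core implement the usual RLWE key switch:

    c0 = a0 + <g^{-1}(a1), b[0]>
    c1 =      <g^{-1}(a1), b[1]>
    =>  c0 + c1 * s_out  ~  a0 + a1 * s_in

so the plaintext is preserved while the decryption key changes from s_in to s_out. The raw-pointer core exists so that key_switch_rlwe_inplace can pass the same ciphertext as both source and destination.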
@@ -1,10 +1,13 @@
pub mod automorphism;
pub mod ciphertext;
pub mod decryptor;
pub mod elem;
pub mod encryptor;
pub mod gadget_product;
pub mod key_generator;
pub mod key_switching;
pub mod keys;
pub mod parameters;
pub mod plaintext;
pub mod rgsw_product;
pub mod trace;
@@ -1,5 +1,7 @@
use base2k::module::{BACKEND, Module};

pub const DEFAULT_SIGMA: f64 = 3.2;

pub struct ParametersLiteral {
    pub backend: BACKEND,
    pub log_n: usize,
@@ -43,12 +43,7 @@ impl Plaintext {
        ))
    }

    pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
        Self(Elem::<VecZnx>::from_bytes_borrow(
            module, log_base2k, log_q, 1, bytes,
        ))
|||||||
@@ -1,54 +1,300 @@
|
|||||||
use crate::{
|
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||||
ciphertext::Ciphertext,
|
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
|
||||||
elem::{Elem, ElemCommon, ElemVecZnx},
|
|
||||||
};
|
|
||||||
use base2k::{
|
|
||||||
Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps,
|
|
||||||
};
|
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
|
||||||
|
impl Parameters {
|
||||||
|
pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||||
|
rgsw_product_tmp_bytes(
|
||||||
|
self.module(),
|
||||||
|
self.log_base2k(),
|
||||||
|
res_logq,
|
||||||
|
in_logq,
|
||||||
|
gct_logq,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||||
|
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
|
||||||
|
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
|
||||||
|
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
|
||||||
|
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
|
||||||
|
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
|
||||||
|
+ 2 * module.bytes_of_vec_znx_dft(gct_cols);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn rgsw_product(
|
pub fn rgsw_product(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
_res: &mut Elem<VecZnx>,
|
c: &mut Ciphertext<VecZnx>,
|
||||||
a: &Ciphertext<VecZnx>,
|
a: &Ciphertext<VecZnx>,
|
||||||
b: &Ciphertext<VmpPMat>,
|
b: &Ciphertext<VmpPMat>,
|
||||||
|
b_cols: usize,
|
||||||
tmp_bytes: &mut [u8],
|
tmp_bytes: &mut [u8],
|
||||||
) {
|
) {
|
||||||
let _log_base2k: usize = b.log_base2k();
|
#[cfg(debug_assertions)]
|
||||||
let rows: usize = min(b.rows(), a.cols());
|
{
|
||||||
let cols: usize = b.cols();
|
assert!(b_cols <= b.cols());
|
||||||
let in_cols = a.cols();
|
assert_eq!(c.size(), 2);
|
||||||
let out_cols: usize = a.cols();
|
assert_eq!(a.size(), 2);
|
||||||
|
assert_eq!(b.size(), 4);
|
||||||
|
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
|
||||||
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols);
|
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
|
||||||
let bytes_of_vmp_apply_dft_to_dft =
|
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
|
||||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols);
|
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
|
||||||
|
|
||||||
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
|
||||||
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
|
||||||
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
|
||||||
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
|
||||||
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
|
|
||||||
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
|
|
||||||
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
|
|
||||||
|
|
||||||
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft);
|
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
|
||||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
|
||||||
let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
|
|
||||||
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
|
|
||||||
let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
|
|
||||||
|
|
||||||
// c0_dft <- DFT(a[0])
|
module.vec_znx_dft(&mut ai_dft, a.at(0));
|
||||||
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
|
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
|
||||||
|
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
|
||||||
|
|
||||||
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
|
module.vec_znx_dft(&mut ai_dft, a.at(1));
|
||||||
module.vmp_apply_dft_to_dft(
|
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
|
||||||
&mut r1_dft,
|
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
|
||||||
&c1_dft,
|
|
||||||
&b.0.value[0],
|
|
||||||
bytes_of_vmp_apply_dft_to_dft,
|
|
||||||
);
|
|
||||||
|
|
||||||
// c1_dft <- DFT(a[1])
|
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
|
||||||
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
|
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
|
||||||
|
|
||||||
|
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes);
|
||||||
|
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rgsw_product_inplace(
|
||||||
|
module: &Module,
|
||||||
|
a: &mut Ciphertext<VecZnx>,
|
||||||
|
b: &Ciphertext<VmpPMat>,
|
||||||
|
b_cols: usize,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert!(b_cols <= b.cols());
|
||||||
|
assert_eq!(a.size(), 2);
|
||||||
|
assert_eq!(b.size(), 4);
|
||||||
|
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
|
||||||
|
assert_alignement(tmp_bytes.as_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
|
||||||
|
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
|
||||||
|
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
|
||||||
|
|
||||||
|
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
|
||||||
|
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
|
||||||
|
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
|
||||||
|
|
||||||
|
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
|
||||||
|
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
|
||||||
|
|
||||||
|
module.vec_znx_dft(&mut ai_dft, a.at(0));
|
||||||
|
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
|
||||||
|
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
|
||||||
|
|
||||||
|
module.vec_znx_dft(&mut ai_dft, a.at(1));
|
||||||
|
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
|
||||||
|
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
|
||||||
|
|
||||||
|
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
|
||||||
|
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
|
||||||
|
|
||||||
|
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes);
|
||||||
|
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use crate::{
|
||||||
|
ciphertext::{Ciphertext, new_rgsw_ciphertext},
|
||||||
|
decryptor::decrypt_rlwe,
|
||||||
|
elem::ElemCommon,
|
||||||
|
encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk},
|
||||||
|
keys::SecretKey,
|
||||||
|
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
|
||||||
|
plaintext::Plaintext,
|
||||||
|
rgsw_product::rgsw_product_inplace,
|
||||||
|
};
|
||||||
|
use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned};
|
||||||
|
use sampling::source::{Source, new_seed};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rgsw_product() {
|
||||||
|
let log_base2k: usize = 10;
|
||||||
|
let log_q: usize = 50;
|
||||||
|
let log_p: usize = 15;
|
||||||
|
|
||||||
|
// Basic parameters with enough limbs to test edge cases
|
||||||
|
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||||
|
backend: BACKEND::FFT64,
|
||||||
|
log_n: 12,
|
||||||
|
log_q: log_q,
|
||||||
|
log_p: log_p,
|
||||||
|
log_base2k: log_base2k,
|
||||||
|
log_scale: 20,
|
||||||
|
xe: 3.2,
|
||||||
|
xs: 1 << 11,
|
||||||
|
};
|
||||||
|
|
||||||
|
let params: Parameters = Parameters::new(¶ms_lit);
|
||||||
|
|
||||||
|
let module: &Module = params.module();
|
||||||
|
let log_q: usize = params.log_q();
|
||||||
|
let log_qp: usize = params.log_qp();
|
||||||
|
let gct_rows: usize = params.cols_q();
|
||||||
|
let gct_cols: usize = params.cols_qp();
|
||||||
|
|
||||||
|
// scratch space
|
||||||
|
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||||
|
params.decrypt_rlwe_tmp_byte(log_q)
|
||||||
|
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
|
||||||
|
| params.rgsw_product_tmp_bytes(log_q, log_q, log_qp)
|
||||||
|
| params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Samplers for public and private randomness
|
||||||
|
let mut source_xe: Source = Source::new(new_seed());
|
||||||
|
let mut source_xa: Source = Source::new(new_seed());
|
||||||
|
let mut source_xs: Source = Source::new(new_seed());
|
||||||
|
|
||||||
|
let mut sk: SecretKey = SecretKey::new(module);
|
||||||
|
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||||
|
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
|
||||||
|
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||||
|
|
||||||
|
let mut ct_rgsw: Ciphertext<VmpPMat> = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp);
|
||||||
|
|
||||||
|
let k: i64 = 3;
|
||||||
|
|
||||||
|
// X^k
|
||||||
|
let m: Scalar = module.new_scalar();
|
||||||
|
let data: &mut [i64] = m.raw_mut();
|
||||||
|
data[k as usize] = 1;
|
||||||
|
|
||||||
|
encrypt_rgsw_sk(
|
||||||
|
module,
|
||||||
|
&mut ct_rgsw,
|
||||||
|
&m,
|
||||||
|
&sk_svp_ppol,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
DEFAULT_SIGMA,
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
let log_k: usize = 2 * log_base2k;
|
||||||
|
|
||||||
|
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||||
|
let mut pt: Plaintext = params.new_plaintext(log_q);
|
||||||
|
let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
|
||||||
|
|
||||||
|
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
|
||||||
|
|
||||||
|
module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0));
|
||||||
|
|
||||||
|
encrypt_rlwe_sk(
|
||||||
|
module,
|
||||||
|
&mut ct.elem_mut(),
|
||||||
|
Some(pt.at(0)),
|
||||||
|
&sk_svp_ppol,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
params.xe(),
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes);
|
||||||
|
|
||||||
|
decrypt_rlwe(
|
||||||
|
module,
|
||||||
|
pt.elem_mut(),
|
||||||
|
ct.elem(),
|
||||||
|
&sk_svp_ppol,
|
||||||
|
&mut tmp_bytes,
|
||||||
|
);
|
||||||
|
|
||||||
|
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0));
|
||||||
|
|
||||||
|
// pt.at(0).print(pt.cols(), 16);
|
||||||
|
|
||||||
|
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
|
||||||
|
|
||||||
|
let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
|
||||||
|
let var_a0_err: f64 = params.xe() * params.xe();
|
||||||
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
|
let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q());
|
||||||
|
|
||||||
|
println!("noise_pred: {}", noise_pred);
|
||||||
|
println!("noise_have: {}", noise_have);
|
||||||
|
|
||||||
|
assert!(noise_have <= noise_pred + 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parameters {
|
||||||
|
pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 {
|
||||||
|
let n: f64 = self.n() as f64;
|
||||||
|
let var_xs: f64 = self.xs() as f64;
|
||||||
|
|
||||||
|
let var_gct_err_lhs: f64;
|
||||||
|
let var_gct_err_rhs: f64;
|
||||||
|
if b_logq < self.log_qp() {
|
||||||
|
let var_round: f64 = 1f64 / 12f64;
|
||||||
|
var_gct_err_lhs = var_round;
|
||||||
|
var_gct_err_rhs = var_round;
|
||||||
|
} else {
|
||||||
|
var_gct_err_lhs = self.xe() * self.xe();
|
||||||
|
var_gct_err_rhs = 0f64;
|
||||||
|
}
|
||||||
|
|
||||||
|
noise_rgsw_product(
|
||||||
|
n,
|
||||||
|
self.log_base2k(),
|
||||||
|
var_xs,
|
||||||
|
var_msg,
|
||||||
|
var_a0_err,
|
||||||
|
var_a1_err,
|
||||||
|
var_gct_err_lhs,
|
||||||
|
var_gct_err_rhs,
|
||||||
|
a_logq,
|
||||||
|
b_logq,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn noise_rgsw_product(
|
||||||
|
n: f64,
|
||||||
|
log_base2k: usize,
|
||||||
|
var_xs: f64,
|
||||||
|
var_msg: f64,
|
||||||
|
var_a0_err: f64,
|
||||||
|
var_a1_err: f64,
|
||||||
|
var_gct_err_lhs: f64,
|
||||||
|
var_gct_err_rhs: f64,
|
||||||
|
a_logq: usize,
|
||||||
|
b_logq: usize,
|
||||||
|
) -> f64 {
|
||||||
|
let a_logq: usize = min(a_logq, b_logq);
|
||||||
|
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||||
|
|
||||||
|
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||||
|
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||||
|
|
||||||
|
let base: f64 = (1 << (log_base2k)) as f64;
|
||||||
|
let var_base: f64 = base * base / 12f64;
|
||||||
|
|
||||||
|
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||||
|
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||||
|
let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||||
|
noise += var_msg * var_a0_err * a_scale * a_scale * n;
|
||||||
|
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
|
||||||
|
noise = noise.sqrt();
|
||||||
|
noise /= b_scale;
|
||||||
|
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||||
}
|
}
|
||||||
|
|||||||
113
rlwe/src/test.rs
Normal file
113
rlwe/src/test.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
use base2k::{alloc_aligned, SvpPPol, SvpPPolOps, VecZnx, BACKEND};
|
||||||
|
use sampling::source::{Source, new_seed};
|
||||||
|
use crate::{ciphertext::Ciphertext, decryptor::decrypt_rlwe, elem::ElemCommon, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULT_SIGMA}, plaintext::Plaintext};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
pub struct Context{
|
||||||
|
pub params: Parameters,
|
||||||
|
pub sk0: SecretKey,
|
||||||
|
pub sk0_ppol:SvpPPol,
|
||||||
|
pub sk1: SecretKey,
|
||||||
|
pub sk1_ppol: SvpPPol,
|
||||||
|
pub tmp_bytes: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context{
|
||||||
|
pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self{
|
||||||
|
|
||||||
|
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||||
|
backend: BACKEND::FFT64,
|
||||||
|
log_n: log_n,
|
||||||
|
log_q: log_q,
|
||||||
|
log_p: log_p,
|
||||||
|
log_base2k: log_base2k,
|
||||||
|
log_scale: 20,
|
||||||
|
xe: DEFAULT_SIGMA,
|
||||||
|
xs: 1 << (log_n-1),
|
||||||
|
};
|
||||||
|
|
||||||
|
let params: Parameters =Parameters::new(¶ms_lit);
|
||||||
|
let module = params.module();
|
||||||
|
|
||||||
|
let log_q: usize = params.log_q();
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new(new_seed());
|
||||||
|
|
||||||
|
let mut sk0: SecretKey = SecretKey::new(module);
|
||||||
|
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||||
|
let mut sk0_ppol: base2k::SvpPPol = module.new_svp_ppol();
|
||||||
|
module.svp_prepare(&mut sk0_ppol, &sk0.0);
|
||||||
|
|
||||||
|
let mut sk1: SecretKey = SecretKey::new(module);
|
||||||
|
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||||
|
let mut sk1_ppol: base2k::SvpPPol = module.new_svp_ppol();
|
||||||
|
module.svp_prepare(&mut sk1_ppol, &sk1.0);
|
||||||
|
|
||||||
|
let tmp_bytes: Vec<u8> = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q)| params.encrypt_rlwe_sk_tmp_bytes(log_q));
|
||||||
|
|
||||||
|
Context{
|
||||||
|
params: params,
|
||||||
|
sk0: sk0,
|
||||||
|
sk0_ppol: sk0_ppol,
|
||||||
|
sk1: sk1,
|
||||||
|
sk1_ppol: sk1_ppol,
|
||||||
|
tmp_bytes: tmp_bytes,
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext<VecZnx>){
|
||||||
|
|
||||||
|
let mut source_xe: Source = Source::new(new_seed());
|
||||||
|
let mut source_xa: Source = Source::new(new_seed());
|
||||||
|
|
||||||
|
encrypt_rlwe_sk(
|
||||||
|
self.params.module(),
|
||||||
|
ct.elem_mut(),
|
||||||
|
Some(pt.elem()),
|
||||||
|
&self.sk0_ppol,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
self.params.xe(),
|
||||||
|
&mut self.tmp_bytes,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext<VecZnx>, pt: &Plaintext){
|
||||||
|
|
||||||
|
let mut source_xe: Source = Source::new(new_seed());
|
||||||
|
let mut source_xa: Source = Source::new(new_seed());
|
||||||
|
|
||||||
|
encrypt_rlwe_sk(
|
||||||
|
self.params.module(),
|
||||||
|
ct.elem_mut(),
|
||||||
|
Some(pt.elem()),
|
||||||
|
&self.sk1_ppol,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
self.params.xe(),
|
||||||
|
&mut self.tmp_bytes,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
|
||||||
|
decrypt_rlwe(
|
||||||
|
self.params.module(),
|
||||||
|
pt.elem_mut(),
|
||||||
|
ct.elem(),
|
||||||
|
&self.sk0_ppol,
|
||||||
|
&mut self.tmp_bytes,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
|
||||||
|
decrypt_rlwe(
|
||||||
|
self.params.module(),
|
||||||
|
pt.elem_mut(),
|
||||||
|
ct.elem(),
|
||||||
|
&self.sk1_ppol,
|
||||||
|
&mut self.tmp_bytes,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
236
rlwe/src/trace.rs
Normal file
236
rlwe/src/trace.rs
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||||
|
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub fn trace_galois_elements(module: &Module) -> Vec<i64> {
|
||||||
|
let mut gal_els: Vec<i64> = Vec::new();
|
||||||
|
(0..module.log_n()).for_each(|i| {
|
||||||
|
if i == 0 {
|
||||||
|
gal_els.push(-1);
|
||||||
|
} else {
|
||||||
|
gal_els.push(module.galois_element(1 << (i - 1)));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
gal_els
|
||||||
|
}
impl Parameters {
    pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
        self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq)
    }
}

pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
    return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
        + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
}
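
The size returned here is what `trace_inplace` asserts its scratch buffer against in debug builds. A caller would typically allocate one aligned buffer of at least this size and reuse it across calls; a sketch, where `key_rows` and `key_cols` are hypothetical names for the rows/cols of the automorphism keys:

    let tmp_len: usize = trace_tmp_bytes(module, ct.cols(), ct.cols(), key_rows, key_cols);
    let mut tmp_bytes: Vec<u8> = base2k::alloc_aligned(tmp_len);
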
pub fn trace_inplace(
    module: &Module,
    a: &mut Ciphertext<VecZnx>,
    start: usize,
    end: usize,
    b: &HashMap<i64, AutomorphismKey>,
    b_cols: usize,
    tmp_bytes: &mut [u8],
) {
    let cols: usize = a.cols();

    let b_rows: usize;

    if let Some((_, key)) = b.iter().next() {
        b_rows = key.value.rows();
        #[cfg(debug_assertions)]
        {
            println!("{} {}", b_cols, key.value.cols());
            assert!(b_cols <= key.value.cols())
        }
    } else {
        panic!("b: HashMap<i64, AutomorphismKey>, is empty")
    }

    #[cfg(debug_assertions)]
    {
        assert!(start <= end);
        assert!(end <= module.n());
        assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols));
        assert_alignement(tmp_bytes.as_ptr());
    }

    let cols: usize = std::cmp::min(b_cols, a.cols());

    let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));

    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
    let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();

    let log_base2k: usize = a.log_base2k();

    (start..end).for_each(|i| {
        a.at_mut(0).rsh(log_base2k, 1, tmp_bytes);
        a.at_mut(1).rsh(log_base2k, 1, tmp_bytes);

        let p: i64;
        if i == 0 {
            p = -1;
        } else {
            p = module.galois_element(1 << (i - 1));
        }

        if let Some(key) = b.get(&p) {
            module.vec_znx_dft(&mut a1_dft, a.at(1));

            // a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(<DFT(a[1]), key[0]>)))
            module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes);
            module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
            module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
            module.vec_znx_big_automorphism_inplace(p, &mut res_big);
            module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
            module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);

            // a[1] = NORMALIZE(a[1] + AUTO(IDFT(<DFT(a[1]), key[1]>)))
            module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes);
            module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
            module.vec_znx_big_automorphism_inplace(p, &mut res_big);
            module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1));
            module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
        } else {
            panic!("b[{}] is empty", p)
        }
    })
}
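
In plaintext terms, each iteration replaces the message by m <- (m + phi_p(m)) / 2, where phi_p is the automorphism selected by the current Galois element p and the division by two is supplied by the 1-bit `rsh` applied before the key switch. Run over the full range 0..log_n this normalized trace keeps only the constant coefficient of the message (up to the noise added by the key switches); a partial range computes the corresponding partial trace onto a subring.
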
#[cfg(test)]
mod test {
    use super::{trace_galois_elements, trace_inplace};
    use crate::{
        automorphism::AutomorphismKey,
        ciphertext::Ciphertext,
        decryptor::decrypt_rlwe,
        elem::ElemCommon,
        encryptor::encrypt_rlwe_sk,
        keys::SecretKey,
        parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
        plaintext::Plaintext,
    };
    use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned};
    use sampling::source::{Source, new_seed};
    use std::collections::HashMap;

    #[test]
    fn test_trace_inplace() {
        let log_base2k: usize = 10;
        let log_q: usize = 50;
        let log_p: usize = 15;

        // Basic parameters with enough limbs to test edge cases
        let params_lit: ParametersLiteral = ParametersLiteral {
            backend: BACKEND::FFT64,
            log_n: 12,
            log_q: log_q,
            log_p: log_p,
            log_base2k: log_base2k,
            log_scale: 20,
            xe: 3.2,
            xs: 1 << 11,
        };

        let params: Parameters = Parameters::new(&params_lit);

        let module: &Module = params.module();
        let log_q: usize = params.log_q();
        let log_qp: usize = params.log_qp();
        let gct_rows: usize = params.cols_q();
        let gct_cols: usize = params.cols_qp();

        // scratch space
        let mut tmp_bytes: Vec<u8> = alloc_aligned(
            params.decrypt_rlwe_tmp_byte(log_q)
                | params.encrypt_rlwe_sk_tmp_bytes(log_q)
                | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
                | params.automorphism_tmp_bytes(log_q, log_q, log_qp),
        );

        // Samplers for public and private randomness
        let mut source_xe: Source = Source::new(new_seed());
        let mut source_xa: Source = Source::new(new_seed());
        let mut source_xs: Source = Source::new(new_seed());

        let mut sk: SecretKey = SecretKey::new(module);
        sk.fill_ternary_hw(params.xs(), &mut source_xs);
        let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
        module.svp_prepare(&mut sk_svp_ppol, &sk.0);

        let gal_els: Vec<i64> = trace_galois_elements(module);

        let auto_keys: HashMap<i64, AutomorphismKey> = AutomorphismKey::new_many(
            module,
            &gal_els,
            &sk,
            log_base2k,
            gct_rows,
            log_qp,
            &mut source_xa,
            &mut source_xe,
            DEFAULT_SIGMA,
            &mut tmp_bytes,
        );

        let mut data: Vec<i64> = vec![0i64; params.n()];
        data.iter_mut().enumerate().for_each(|(i, x)| *x = 1 + i as i64);

        let log_k: usize = 2 * log_base2k;

        let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
        let mut pt: Plaintext = params.new_plaintext(log_q);

        pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
        pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);

        pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);

        pt.at(0).print(pt.cols(), 16);

        encrypt_rlwe_sk(
            module,
            &mut ct.elem_mut(),
            Some(pt.at(0)),
            &sk_svp_ppol,
            &mut source_xa,
            &mut source_xe,
            params.xe(),
            &mut tmp_bytes,
        );

        trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes);
        trace_inplace(
            module,
            &mut ct,
            4,
            module.log_n(),
            &auto_keys,
            gct_cols,
            &mut tmp_bytes,
        );

        // pt = dec(trace(ct))
        decrypt_rlwe(
            module,
            pt.elem_mut(),
            ct.elem(),
            &sk_svp_ppol,
            &mut tmp_bytes,
        );

        pt.at(0).print(pt.cols(), 16);

        pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);

        println!("trace: {:?}", &data[..16]);
    }
}
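
Given the input above (data[i] = 1 + i) and the usual behaviour of the normalized trace, the decoded result should be approximately [1, 0, 0, ...] in the first coefficients, up to encryption and key-switching noise; the test currently only prints the first 16 values rather than asserting on them.
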
79  rustfmt.toml  Normal file
@@ -0,0 +1,79 @@
max_width = 130
hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
indent_style = "Block"
use_small_heuristics = "Default"
fn_call_width = 60
attr_fn_like_width = 100
struct_lit_width = 18
struct_variant_width = 35
array_width = 60
chain_width = 60
single_line_if_else_max_width = 50
single_line_let_else_max_width = 50
wrap_comments = false
format_code_in_doc_comments = true
doc_comment_code_block_width = 100
comment_width = 80
normalize_comments = true
normalize_doc_attributes = true
format_strings = true
format_macro_matchers = false
format_macro_bodies = true
skip_macro_invocations = []
hex_literal_case = "Preserve"
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
imports_granularity = "Preserve"
group_imports = "Preserve"
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
short_array_element_width_threshold = 10
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
match_arm_leading_pipes = "Never"
force_multiline_blocks = false
fn_params_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2024"
style_edition = "2024"
inline_attribute_width = 0
format_generated_files = true
generated_marker_line_search_limit = 5
merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
required_version = "1.8.0"
unstable_features = true
disable_all_formatting = false
skip_children = false
show_parse_errors = true
error_on_line_overflow = false
error_on_unformatted = false
ignore = []
emit_mode = "Files"
make_backup = false
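
Since this configuration sets `unstable_features = true` (and pins `required_version = "1.8.0"`), applying it requires a nightly rustfmt; on a stable toolchain the unstable options are either rejected or ignored with warnings.
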
@@ -1,5 +1,5 @@
-use rand_chacha::rand_core::SeedableRng;
 use rand_chacha::ChaCha8Rng;
+use rand_chacha::rand_core::SeedableRng;
 use rand_core::{OsRng, RngCore};
 
 const MAXF64: f64 = 9007199254740992.0;