Merge branch 'dev_trace'

This commit is contained in:
Jean-Philippe Bossuat
2025-04-24 19:14:56 +02:00
38 changed files with 1947 additions and 947 deletions

2
.gitmodules vendored
View File

@@ -1,3 +1,3 @@
[submodule "base2k/spqlios-arithmetic"] [submodule "base2k/spqlios-arithmetic"]
path = base2k/spqlios-arithmetic path = base2k/spqlios-arithmetic
url = https://github.com/Pro7ech/spqlios-arithmetic url = https://github.com/phantomzone-org/spqlios-arithmetic

View File

@@ -1,14 +1,9 @@
use base2k::ffi::reim::*; use base2k::ffi::reim::*;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use std::ffi::c_void; use std::ffi::c_void;
fn fft(c: &mut Criterion) { fn fft(c: &mut Criterion) {
fn forward<'a>( fn forward<'a>(m: u32, log_bound: u32, reim_fft_precomp: *mut reim_fft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
m: u32,
log_bound: u32,
reim_fft_precomp: *mut reim_fft_precomp,
a: &'a [i64],
) -> Box<dyn FnMut() + 'a> {
unsafe { unsafe {
let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0); let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr()); reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
@@ -16,12 +11,7 @@ fn fft(c: &mut Criterion) {
} }
} }
fn backward<'a>( fn backward<'a>(m: u32, log_bound: u32, reim_ifft_precomp: *mut reim_ifft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
m: u32,
log_bound: u32,
reim_ifft_precomp: *mut reim_ifft_precomp,
a: &'a [i64],
) -> Box<dyn FnMut() + 'a> {
Box::new(move || unsafe { Box::new(move || unsafe {
let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0); let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr()); reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
@@ -29,8 +19,7 @@ fn fft(c: &mut Criterion) {
}) })
} }
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fft");
c.benchmark_group("fft");
for log_n in 10..17 { for log_n in 10..17 {
let n: usize = 1 << log_n; let n: usize = 1 << log_n;

View File

@@ -1,6 +1,6 @@
use base2k::{ use base2k::{
alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND, VecZnxDftOps, VecZnxOps, alloc_aligned,
}; };
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;
@@ -38,13 +38,13 @@ fn main() {
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.cols()); let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.cols());
// Applies buf_dft <- s * a // Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.cols()); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
// Alias scratch space // Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
// buf_big <- IDFT(buf_dft) (not normalized) // buf_big <- IDFT(buf_dft) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.cols()); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
let mut m: VecZnx = module.new_vec_znx(msg_cols); let mut m: VecZnx = module.new_vec_znx(msg_cols);
@@ -74,8 +74,8 @@ fn main() {
// Decrypt // Decrypt
// buf_big <- a * s // buf_big <- a * s
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.cols()); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.cols()); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
// buf_big <- a * s + b // buf_big <- a * s + b
module.vec_znx_big_add_small_inplace(&mut buf_big, &b); module.vec_znx_big_add_small_inplace(&mut buf_big, &b);

View File

@@ -1,6 +1,6 @@
use base2k::{ use base2k::{
alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat,
VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, BACKEND, VmpPMatOps, alloc_aligned,
}; };
fn main() { fn main() {
@@ -16,8 +16,7 @@ fn main() {
let cols: usize = cols + 1; let cols: usize = cols + 1;
// Maximum size of the byte scratch needed // Maximum size of the byte scratch needed
let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
| module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
let mut buf: Vec<u8> = alloc_aligned(tmp_bytes); let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
@@ -49,7 +48,7 @@ fn main() {
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, cols); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
let mut res: VecZnx = module.new_vec_znx(cols); let mut res: VecZnx = module.new_vec_znx(cols);
module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);

View File

@@ -40,14 +40,7 @@ pub trait Encoding {
/// * `i`: index of the coefficient on which to encode the data. /// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode on the receiver. /// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data. /// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_coeff_i64( fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
&mut self,
log_base2k: usize,
log_k: usize,
i: usize,
data: i64,
log_max: usize,
);
/// decode a single of i64 from the receiver at the given index. /// decode a single of i64 from the receiver at the given index.
/// ///
@@ -73,14 +66,7 @@ impl Encoding for VecZnx {
decode_vec_float(self, log_base2k, data) decode_vec_float(self, log_base2k, data)
} }
fn encode_coeff_i64( fn encode_coeff_i64(&mut self, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
&mut self,
log_base2k: usize,
log_k: usize,
i: usize,
value: i64,
log_max: usize,
) {
encode_coeff_i64(self, log_base2k, log_k, i, value, log_max) encode_coeff_i64(self, log_base2k, log_k, i, value, log_max)
} }
@@ -119,8 +105,7 @@ fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64],
.enumerate() .enumerate()
.for_each(|(i, i_rev)| { .for_each(|(i, i_rev)| {
let shift: usize = i * log_base2k; let shift: usize = i * log_base2k;
izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
.for_each(|(y, x)| *y = (x >> shift) & mask);
}) })
} }
@@ -189,14 +174,7 @@ fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) {
}); });
} }
fn encode_coeff_i64( fn encode_coeff_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
a: &mut VecZnx,
log_base2k: usize,
log_k: usize,
i: usize,
value: i64,
log_max: usize,
) {
debug_assert!(i < a.n()); debug_assert!(i < a.n());
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
debug_assert!( debug_assert!(

View File

@@ -62,10 +62,7 @@ unsafe extern "C" {
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_fft_precomp_get_buffer( pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
tables: *const REIM_FFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64; pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
@@ -80,10 +77,7 @@ unsafe extern "C" {
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_ifft_precomp_get_buffer( pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
tables: *const REIM_IFFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
@@ -92,120 +86,58 @@ unsafe extern "C" {
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_fftvec_mul( pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
tables: *const REIM_FFTVEC_MUL_PRECOMP,
r: *mut f64,
a: *const f64,
b: *const f64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_fftvec_addmul( pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
tables: *const REIM_FFTVEC_ADDMUL_PRECOMP,
r: *mut f64,
a: *const f64,
b: *const f64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_from_znx32_precomp( pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
m: u32,
log2bound: u32,
) -> *mut REIM_FROM_ZNX32_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_znx32( pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
tables: *const REIM_FROM_ZNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_znx64( pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
tables: *const REIM_FROM_ZNX64_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_znx64_simple( pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
a: *const i64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_tnx32( pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
tables: *const REIM_FROM_TNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_to_tnx32_precomp( pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX32_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_tnx32( pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
tables: *const REIM_TO_TNX32_PRECOMP,
r: *mut i32,
a: *const ::std::os::raw::c_void,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_to_tnx_precomp( pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_tnx_simple( pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
m: u32,
divisor: f64,
log2overhead: u32,
r: *mut f64,
a: *const f64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn new_reim_to_znx64_precomp( pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
m: u32,
divisor: f64,
log2bound: u32,
) -> *mut REIM_TO_ZNX64_PRECOMP;
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_znx64( pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
precomp: *const REIM_TO_ZNX64_PRECOMP,
r: *mut i64,
a: *const ::std::os::raw::c_void,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_znx64_simple( pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
m: u32,
divisor: f64,
log2bound: u32,
r: *mut i64,
a: *const ::std::os::raw::c_void,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
@@ -230,22 +162,11 @@ unsafe extern "C" {
); );
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_znx32_simple( pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
x: *const i32,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn reim_to_tnx32_simple( pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
m: u32,
divisor: f64,
log2overhead: u32,
r: *mut i32,
x: *const ::std::os::raw::c_void,
);
} }

View File

@@ -44,14 +44,7 @@ unsafe extern "C" {
); );
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vec_znx_dft( pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vec_znx_idft( pub unsafe fn vec_znx_idft(

View File

@@ -37,13 +37,22 @@ unsafe extern "C" {
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes( pub unsafe fn vmp_apply_dft_add(
module: *const MODULE, module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64, res_size: u64,
a: *const i64,
a_size: u64, a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64, nrows: u64,
ncols: u64, ncols: u64,
) -> u64; tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
} }
unsafe extern "C" { unsafe extern "C" {
@@ -60,6 +69,20 @@ unsafe extern "C" {
); );
} }
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE, module: *const MODULE,

View File

@@ -64,24 +64,11 @@ unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64); pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn znx_normalize( pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
nn: u64,
base_k: u64,
out: *mut i64,
carry_out: *mut i64,
in_: *const i64,
carry_in: *const i64,
);
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn znx_small_single_product( pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
module: *const MODULE,
res: *mut i64,
a: *const i64,
b: *const i64,
tmp: *mut u8,
);
} }
unsafe extern "C" { unsafe extern "C" {

View File

@@ -1,11 +1,5 @@
pub mod encoding; pub mod encoding;
#[allow( #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
dead_code,
improper_ctypes
)]
// Other modules and exports // Other modules and exports
pub mod ffi; pub mod ffi;
pub mod infos; pub mod infos;
@@ -42,7 +36,10 @@ pub fn is_aligned<T>(ptr: *const T) -> bool {
} }
pub fn assert_alignement<T>(ptr: *const T) { pub fn assert_alignement<T>(ptr: *const T) {
assert!(is_aligned(ptr), "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]") assert!(
is_aligned(ptr),
"invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]"
)
} }
pub fn cast<T, V>(data: &[T]) -> &[V] { pub fn cast<T, V>(data: &[T]) -> &[V] {
@@ -57,7 +54,7 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
unsafe { std::slice::from_raw_parts_mut(ptr, len) } unsafe { std::slice::from_raw_parts_mut(ptr, len) }
} }
use std::alloc::{alloc, Layout}; use std::alloc::{Layout, alloc};
use std::ptr; use std::ptr;
/// Allocates a block of bytes with a custom alignement. /// Allocates a block of bytes with a custom alignement.

View File

@@ -1,5 +1,5 @@
use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE};
use crate::GALOISGENERATOR; use crate::GALOISGENERATOR;
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
#[repr(u8)] #[repr(u8)]
@@ -51,27 +51,21 @@ impl Module {
(self.n() << 1) as _ (self.n() << 1) as _
} }
// GALOISGENERATOR^|gen| * sign(gen) // Returns GALOISGENERATOR^|gen| * sign(gen)
pub fn galois_element(&self, gen: i64) -> i64 { pub fn galois_element(&self, gen: i64) -> i64 {
if gen == 0 { if gen == 0 {
return 1; return 1;
} }
((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum()
let mut gal_el: u64 = 1;
let mut gen_1_pow: u64 = GALOISGENERATOR;
let mut e: usize = gen.abs() as usize;
while e > 0 {
if e & 1 == 1 {
gal_el = gal_el.wrapping_mul(gen_1_pow);
} }
gen_1_pow = gen_1_pow.wrapping_mul(gen_1_pow); // Returns gen^-1
e >>= 1; pub fn galois_element_inv(&self, gen: i64) -> i64 {
if gen == 0 {
panic!("cannot invert 0")
} }
((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
gal_el &= self.cyclotomic_order() - 1; * gen.signum()
(gal_el as i64) * gen.signum()
} }
pub fn free(self) { pub fn free(self) {
@@ -79,3 +73,17 @@ impl Module {
drop(self); drop(self);
} }
} }
fn mod_exp_u64(x: u64, e: usize) -> u64 {
let mut y: u64 = 1;
let mut x_pow: u64 = x;
let mut exp = e;
while exp > 0 {
if exp & 1 == 1 {
y = y.wrapping_mul(x_pow);
}
x_pow = x_pow.wrapping_mul(x_pow);
exp >>= 1;
}
y
}

View File

@@ -18,15 +18,7 @@ pub trait Sampling {
); );
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal( fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
&self,
log_base2k: usize,
a: &mut VecZnx,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
} }
impl Sampling for Module { impl Sampling for Module {
@@ -63,7 +55,7 @@ impl Sampling for Module {
while dist_f64.abs() > bound { while dist_f64.abs() > bound {
dist_f64 = dist.sample(source) dist_f64 = dist.sample(source)
} }
*a += (dist_f64.round() as i64) << log_base2k_rem *a += (dist_f64.round() as i64) << log_base2k_rem;
}); });
} else { } else {
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
@@ -76,15 +68,7 @@ impl Sampling for Module {
} }
} }
fn add_normal( fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
&self,
log_base2k: usize,
a: &mut VecZnx,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
self.add_dist_f64( self.add_dist_f64(
log_base2k, log_base2k,
a, a,

View File

@@ -1,7 +1,7 @@
use crate::{Encoding, Infos, VecZnx}; use crate::{Encoding, Infos, VecZnx};
use rug::Float;
use rug::float::Round; use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
use rug::Float;
impl VecZnx { impl VecZnx {
pub fn std(&self, log_base2k: usize) -> f64 { pub fn std(&self, log_base2k: usize) -> f64 {

View File

@@ -1,8 +1,8 @@
use crate::ffi::svp; use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{assert_alignement, Module, VecZnx, VecZnxDft}; use crate::{BACKEND, Module, VecZnx, VecZnxDft, assert_alignement};
use crate::{alloc_aligned, cast_mut, Infos}; use crate::{Infos, alloc_aligned, cast_mut};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand_core::RngCore; use rand_core::RngCore;
use rand_distr::{Distribution, WeightedIndex}; use rand_distr::{Distribution, WeightedIndex};
@@ -35,15 +35,15 @@ impl Scalar {
self.n self.n
} }
pub fn buffer_size(n: usize) -> usize { pub fn bytes_of(n: usize) -> usize {
n n * std::mem::size_of::<i64>()
} }
pub fn from_buffer(&mut self, n: usize, bytes: &mut [u8]) -> Self { pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::buffer_size(n); let size: usize = Self::bytes_of(n);
debug_assert!( debug_assert!(
bytes.len() == size, bytes.len() == size,
"invalid buffer: bytes.len()={} < self.buffer_size(n={})={}", "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(), bytes.len(),
n, n,
size size
@@ -63,11 +63,37 @@ impl Scalar {
} }
} }
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::new(),
ptr: ptr,
}
}
pub fn as_ptr(&self) -> *const i64 { pub fn as_ptr(&self) -> *const i64 {
self.ptr self.ptr
} }
pub fn raw(&self) -> &[i64] { pub fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
}
pub fn raw_mut(&self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) } unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
} }
@@ -87,26 +113,89 @@ impl Scalar {
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.data.shuffle(source); self.data.shuffle(source);
} }
pub fn as_vec_znx(&self) -> VecZnx {
VecZnx {
n: self.n,
cols: 1,
data: Vec::new(),
ptr: self.ptr,
}
}
} }
pub struct SvpPPol(pub *mut svp::svp_ppol_t, pub usize); pub trait ScalarOps {
fn bytes_of_scalar(&self) -> usize;
fn new_scalar(&self) -> Scalar;
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
}
impl ScalarOps for Module {
fn bytes_of_scalar(&self) -> usize {
Scalar::bytes_of(self.n())
}
fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes(self.n(), bytes)
}
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
}
}
pub struct SvpPPol {
pub n: usize,
pub data: Vec<u8>,
pub ptr: *mut u8,
pub backend: BACKEND,
}
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
/// The backend array of an [SvpPPol] is allocated in C and must be freed manually.
impl SvpPPol { impl SvpPPol {
/// Returns the ring degree of the [SvpPPol]. pub fn new(module: &Module) -> Self {
pub fn n(&self) -> usize { module.new_svp_ppol()
self.1
} }
pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol { /// Returns the ring degree of the [SvpPPol].
pub fn n(&self) -> usize {
self.n
}
pub fn bytes_of(module: &Module) -> usize {
module.bytes_of_svp_ppol()
}
pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
}
unsafe {
Self {
n: module.n(),
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
backend: module.backend(),
}
}
}
pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
}
Self {
n: module.n(),
data: Vec::new(),
ptr: tmp_bytes.as_mut_ptr(),
backend: module.backend(),
} }
debug_assert!(bytes.len() << 3 >= size);
SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size)
} }
/// Returns the number of cols of the [SvpPPol], which is always 1. /// Returns the number of cols of the [SvpPPol], which is always 1.
@@ -120,45 +209,64 @@ pub trait SvpPPolOps {
fn new_svp_ppol(&self) -> SvpPPol; fn new_svp_ppol(&self) -> SvpPPol;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes]. /// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
fn bytes_of_svp_ppol(&self) -> usize; fn bytes_of_svp_ppol(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol]. /// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize); fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
} }
impl SvpPPolOps for Module { impl SvpPPolOps for Module {
fn new_svp_ppol(&self) -> SvpPPol { fn new_svp_ppol(&self) -> SvpPPol {
unsafe { SvpPPol(svp::new_svp_ppol(self.ptr), self.n()) } let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
let ptr: *mut u8 = data.as_mut_ptr();
SvpPPol {
data: data,
ptr: ptr,
n: self.n(),
backend: self.backend(),
}
} }
fn bytes_of_svp_ppol(&self) -> usize { fn bytes_of_svp_ppol(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
} }
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.0, a.as_ptr()) } SvpPPol::from_bytes(self, bytes)
} }
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize) { fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol {
debug_assert!( SvpPPol::from_bytes_borrow(self, tmp_bytes)
c.cols() >= b_cols, }
"invalid c_vector: c_vector.cols()={} < b.cols()={}",
c.cols(), fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
b_cols unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
); }
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
unsafe { unsafe {
svp::svp_apply_dft( svp::svp_apply_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.ptr as *mut vec_znx_dft_t,
b_cols as u64, c.cols() as u64,
a.0, a.ptr as *const svp_ppol_t,
b.as_ptr(), b.as_ptr(),
b_cols as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
) )
} }

View File

@@ -1,8 +1,8 @@
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::ffi::znx; use crate::ffi::znx;
use crate::{alloc_aligned, assert_alignement};
use crate::{Infos, Module}; use crate::{Infos, Module};
use crate::{alloc_aligned, assert_alignement};
use itertools::izip; use itertools::izip;
use std::cmp::min; use std::cmp::min;
@@ -12,16 +12,16 @@ use std::cmp::min;
#[derive(Clone)] #[derive(Clone)]
pub struct VecZnx { pub struct VecZnx {
/// Polynomial degree. /// Polynomial degree.
n: usize, pub n: usize,
/// Number of columns. /// Number of columns.
cols: usize, pub cols: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
data: Vec<i64>, pub data: Vec<i64>,
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
ptr: *mut i64, pub ptr: *mut i64,
} }
pub trait VecZnxVec { pub trait VecZnxVec {
@@ -347,8 +347,11 @@ pub trait VecZnxOps {
/// c <- a - b. /// c <- a - b.
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// b <- a - b.
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- b - a. /// b <- b - a.
fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx); fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -a. /// b <- -a.
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
@@ -363,10 +366,10 @@ pub trait VecZnxOps {
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize); fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize); fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
/// Splits b into subrings and copies them them into a. /// Splits b into subrings and copies them them into a.
/// ///
@@ -452,8 +455,8 @@ impl VecZnxOps for Module {
} }
} }
// b <- a + b // b <- a - b
fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) { fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx::vec_znx_sub(
self.ptr, self.ptr,
@@ -470,6 +473,24 @@ impl VecZnxOps for Module {
} }
} }
// b <- b - a
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
b.as_mut_ptr(),
b.cols() as u64,
b.n() as u64,
b.as_ptr(),
b.cols() as u64,
b.n() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
)
}
}
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_negate( vec_znx::vec_znx_negate(
@@ -540,10 +561,9 @@ impl VecZnxOps for Module {
/// # Panics /// # Panics
/// ///
/// The method will panic if the argument `a` is greater than `a.cols()`. /// The method will panic if the argument `a` is greater than `a.cols()`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) { fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
debug_assert_eq!(b.n(), self.n()); debug_assert_eq!(b.n(), self.n());
debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.ptr, self.ptr,
@@ -552,7 +572,7 @@ impl VecZnxOps for Module {
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
a.as_ptr(), a.as_ptr(),
a_cols as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
); );
} }
@@ -569,9 +589,8 @@ impl VecZnxOps for Module {
/// # Panics /// # Panics
/// ///
/// The method will panic if the argument `cols` is greater than `self.cols()`. /// The method will panic if the argument `cols` is greater than `self.cols()`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) { fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.ptr, self.ptr,
@@ -580,7 +599,7 @@ impl VecZnxOps for Module {
a.cols() as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
a.as_ptr(), a.as_ptr(),
a_cols as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
); );
} }

View File

@@ -1,5 +1,5 @@
use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, BACKEND}; use crate::{BACKEND, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement};
pub struct VecZnxBig { pub struct VecZnxBig {
pub data: Vec<u8>, pub data: Vec<u8>,
@@ -16,6 +16,7 @@ impl VecZnxBig {
pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self { pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols));
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr())
}; };
unsafe { unsafe {
@@ -54,14 +55,6 @@ impl VecZnxBig {
} }
} }
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn backend(&self) -> BACKEND { pub fn backend(&self) -> BACKEND {
self.backend self.backend
} }
@@ -77,12 +70,36 @@ impl VecZnxBig {
} }
} }
impl Infos for VecZnxBig {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
pub trait VecZnxBigOps { pub trait VecZnxBigOps {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig; fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// ///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments /// # Arguments
/// ///
/// * `cols`: the number of cols of the [VecZnxBig]. /// * `cols`: the number of cols of the [VecZnxBig].
@@ -92,6 +109,19 @@ pub trait VecZnxBigOps {
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig; fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxBig].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes]. /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize; fn bytes_of_vec_znx_big(&self, cols: usize) -> usize;
@@ -111,13 +141,7 @@ pub trait VecZnxBigOps {
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
/// b <- normalize(a) /// b <- normalize(a)
fn vec_znx_big_normalize( fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]);
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &VecZnxBig,
tmp_bytes: &mut [u8],
);
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
@@ -151,19 +175,13 @@ impl VecZnxBigOps for Module {
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
debug_assert!(
bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(),
<Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols)
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
VecZnxBig::from_bytes(self, cols, bytes) VecZnxBig::from_bytes(self, cols, bytes)
} }
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig {
VecZnxBig::from_bytes_borrow(self, cols, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize { fn bytes_of_vec_znx_big(&self, cols: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize } unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize }
} }
@@ -232,13 +250,7 @@ impl VecZnxBigOps for Module {
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
} }
fn vec_znx_big_normalize( fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) {
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &VecZnxBig,
tmp_bytes: &mut [u8],
) {
debug_assert!( debug_assert!(
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self), tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",

View File

@@ -1,8 +1,8 @@
use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{alloc_aligned, VecZnx, DEFAULTALIGN}; use crate::{BACKEND, Infos, Module, VecZnxBig, assert_alignement};
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND}; use crate::{DEFAULTALIGN, VecZnx, alloc_aligned};
pub struct VecZnxDft { pub struct VecZnxDft {
pub data: Vec<u8>, pub data: Vec<u8>,
@@ -61,14 +61,6 @@ impl VecZnxDft {
} }
} }
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn backend(&self) -> BACKEND { pub fn backend(&self) -> BACKEND {
self.backend self.backend
} }
@@ -102,12 +94,36 @@ impl VecZnxDft {
} }
} }
impl Infos for VecZnxDft {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
pub trait VecZnxDftOps { pub trait VecZnxDftOps {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft; fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments /// # Arguments
/// ///
/// * `cols`: the number of cols of the [VecZnxDft]. /// * `cols`: the number of cols of the [VecZnxDft].
@@ -117,6 +133,19 @@ pub trait VecZnxDftOps {
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft; fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
/// # Arguments /// # Arguments
@@ -133,28 +162,15 @@ pub trait VecZnxDftOps {
fn vec_znx_idft_tmp_bytes(&self) -> usize; fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space. /// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize); fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]); fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize); fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);
fn vec_znx_dft_automorphism( fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
);
fn vec_znx_dft_automorphism_inplace( fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
} }
@@ -173,37 +189,25 @@ impl VecZnxDftOps for Module {
} }
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!(
tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
tmp_bytes.len(),
Self::bytes_of_vec_znx_dft(self, cols)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
VecZnxDft::from_bytes(self, cols, tmp_bytes) VecZnxDft::from_bytes(self, cols, tmp_bytes)
} }
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft::from_bytes_borrow(self, cols, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize { fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize } unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
} }
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) { fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
debug_assert!(
b.cols() >= a_cols,
"invalid c_vector: b_vector.cols()={} < a_cols={}",
b.cols(),
a_cols
);
unsafe { unsafe {
vec_znx_dft::vec_znx_idft_tmp_a( vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_big_t, b.ptr as *mut vec_znx_big_t,
b.cols() as u64, b.cols() as u64,
a.ptr as *mut vec_znx_dft_t, a.ptr as *mut vec_znx_dft_t,
a_cols as u64, a.cols() as u64,
) )
} }
} }
@@ -216,41 +220,23 @@ impl VecZnxDftOps for Module {
/// ///
/// # Panics /// # Panics
/// If b.cols < a_cols /// If b.cols < a_cols
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) { fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
debug_assert!(
b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
unsafe { unsafe {
vec_znx_dft::vec_znx_dft( vec_znx_dft::vec_znx_dft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_dft_t, b.ptr as *mut vec_znx_dft_t,
b.cols() as u64, b.cols() as u64,
a.as_ptr(), a.as_ptr(),
a_cols as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
) )
} }
} }
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) { fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!( assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
@@ -263,65 +249,31 @@ impl VecZnxDftOps for Module {
vec_znx_dft::vec_znx_idft( vec_znx_dft::vec_znx_idft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_big_t, b.ptr as *mut vec_znx_big_t,
a.cols() as u64, b.cols() as u64,
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
a_cols as u64, a.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
} }
fn vec_znx_dft_automorphism( fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
}
unsafe { unsafe {
vec_znx_dft::vec_znx_dft_automorphism( vec_znx_dft::vec_znx_dft_automorphism(
self.ptr, self.ptr,
k, k,
b.ptr as *mut vec_znx_dft_t, b.ptr as *mut vec_znx_dft_t,
b_cols as u64, b.cols() as u64,
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
a_cols as u64, a.cols() as u64,
[0u8; 0].as_mut_ptr(), [0u8; 0].as_mut_ptr(),
); );
} }
} }
fn vec_znx_dft_automorphism_inplace( fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!( assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self), tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}", "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
@@ -335,9 +287,9 @@ impl VecZnxDftOps for Module {
self.ptr, self.ptr,
k, k,
a.ptr as *mut vec_znx_dft_t, a.ptr as *mut vec_znx_dft_t,
a_cols as u64, a.cols() as u64,
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
a_cols as u64, a.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
); );
} }
@@ -355,11 +307,9 @@ impl VecZnxDftOps for Module {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
alloc_aligned, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND,
};
use itertools::izip; use itertools::izip;
use sampling::source::{new_seed, Source}; use sampling::source::{Source, new_seed};
#[test] #[test]
fn test_automorphism_dft() { fn test_automorphism_dft() {
@@ -379,16 +329,16 @@ mod tests {
let p: i64 = -5; let p: i64 = -5;
// a_dft <- DFT(a) // a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a, cols); module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft) // a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes); module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a) // a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, cols); module.vec_znx_automorphism_inplace(p, &mut a);
// b_dft <- DFT(AUTO(a)) // b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a, cols); module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw(&module); let a_f64: &[f64] = a_dft.raw(&module);
let b_f64: &[f64] = b_dft.raw(&module); let b_f64: &[f64] = b_dft.raw(&module);

View File

@@ -1,9 +1,7 @@
use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t}; use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{ use crate::{BACKEND, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array. /// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -26,6 +24,7 @@ pub struct VmpPMat {
/// The ring degree of each [VecZnxDft]. /// The ring degree of each [VecZnxDft].
n: usize, n: usize,
#[warn(dead_code)]
backend: BACKEND, backend: BACKEND,
} }
@@ -99,8 +98,7 @@ impl VmpPMat {
if self.n < 8 { if self.n < 8 {
res.copy_from_slice( res.copy_from_slice(
&self.raw::<T>()[(row + col * self.rows()) * self.n() &self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)],
..(row + col * self.rows()) * (self.n() + 1)],
); );
} else { } else {
(0..self.n >> 3).for_each(|blk| { (0..self.n >> 3).for_each(|blk| {
@@ -119,10 +117,7 @@ impl VmpPMat {
if col == (ncols - 1) && (ncols & 1 == 1) { if col == (ncols - 1) && (ncols & 1 == 1) {
&self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] &self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
} else { } else {
&self.raw::<T>()[blk * nrows * ncols * 8 &self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
+ (col / 2) * (2 * nrows) * 8
+ row * 2 * 8
+ (col % 2) * 8..]
} }
} }
} }
@@ -219,13 +214,7 @@ pub trait VmpPMatOps {
/// * `a_cols`: number of cols of the input [VecZnx]. /// * `a_cols`: number of cols of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat]. /// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat]. /// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_tmp_bytes( fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
&self,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// ///
@@ -253,6 +242,32 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver.
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
/// ///
/// # Arguments /// # Arguments
@@ -261,13 +276,7 @@ pub trait VmpPMatOps {
/// * `a_cols`: number of cols of the input [VecZnxDft]. /// * `a_cols`: number of cols of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat]. /// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat]. /// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_to_dft_tmp_bytes( fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
&self,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -296,6 +305,33 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// ///
@@ -461,13 +497,7 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_apply_dft_tmp_bytes( fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
&self,
res_cols: usize,
a_cols: usize,
gct_rows: usize,
gct_cols: usize,
) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_tmp_bytes( vmp::vmp_apply_dft_tmp_bytes(
self.ptr, self.ptr,
@@ -480,9 +510,7 @@ impl VmpPMatOps for Module {
} }
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!( debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -503,13 +531,29 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_apply_dft_to_dft_tmp_bytes( fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
&self, debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
res_cols: usize, #[cfg(debug_assertions)]
a_cols: usize, {
gct_rows: usize, assert_alignement(tmp_bytes.as_ptr());
gct_cols: usize, }
) -> usize { unsafe {
vmp::vmp_apply_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes( vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr, self.ptr,
@@ -521,17 +565,8 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_apply_dft_to_dft( fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
&self, debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -551,11 +586,29 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!( debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -579,8 +632,7 @@ impl VmpPMatOps for Module {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
VecZnxOps, VmpPMat, VmpPMatOps,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -598,13 +650,12 @@ mod tests {
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut tmp_bytes: Vec<u8> = let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
for row_i in 0..vpmat_rows { for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source); module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols); module.vec_znx_dft(&mut a_dft, &a);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
@@ -617,7 +668,7 @@ mod tests {
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes); module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module)); assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
} }

View File

@@ -1,13 +1,12 @@
use base2k::{ use base2k::{
Infos, BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
VmpPMat, alloc_aligned_u8,
}; };
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use rlwe::{ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext}, ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::ElemCommon, elem::ElemCommon,
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
gadget_product::{gadget_product_core, gadget_product_tmp_bytes}, gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes},
keys::SecretKey, keys::SecretKey,
parameters::{Parameters, ParametersLiteral}, parameters::{Parameters, ParametersLiteral},
}; };
@@ -19,20 +18,16 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
res_dft_0: &'a mut VecZnxDft, res_dft_0: &'a mut VecZnxDft,
res_dft_1: &'a mut VecZnxDft, res_dft_1: &'a mut VecZnxDft,
a: &'a VecZnx, a: &'a VecZnx,
a_cols: usize,
b: &'a Ciphertext<VmpPMat>, b: &'a Ciphertext<VmpPMat>,
b_cols: usize, b_cols: usize,
tmp_bytes: &'a mut [u8], tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> { ) -> Box<dyn FnMut() + 'a> {
Box::new(move || { Box::new(move || {
gadget_product_core( gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes);
module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes,
);
}) })
} }
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace");
c.benchmark_group("gadget_product_inplace");
for log_n in 10..11 { for log_n in 10..11 {
let params_lit: ParametersLiteral = ParametersLiteral { let params_lit: ParametersLiteral = ParametersLiteral {
@@ -50,7 +45,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8( let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
| gadget_product_tmp_bytes( | gadget_product_core_tmp_bytes(
params.module(), params.module(),
params.log_base2k(), params.log_base2k(),
params.log_q(), params.log_q(),
@@ -119,7 +114,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
.module() .module()
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
let a_cols: usize = a.cols();
let b_cols: usize = gadget_ct.cols(); let b_cols: usize = gadget_ct.cols();
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), { let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
@@ -128,7 +122,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
&mut res_dft_0, &mut res_dft_0,
&mut res_dft_1, &mut res_dft_1,
&mut a, &mut a,
a_cols,
&gadget_ct, &gadget_ct,
b_cols, b_cols,
&mut tmp_bytes, &mut tmp_bytes,

View File

@@ -22,10 +22,8 @@ fn main() {
let params: Parameters = Parameters::new(&params_lit); let params: Parameters = Parameters::new(&params_lit);
let mut tmp_bytes: Vec<u8> = alloc_aligned( let mut tmp_bytes: Vec<u8> =
params.decrypt_rlwe_tmp_byte(params.log_q()) alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()));
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q()),
);
let mut source: Source = Source::new([0; 32]); let mut source: Source = Source::new([0; 32]);
let mut sk: SecretKey = SecretKey::new(params.module()); let mut sk: SecretKey = SecretKey::new(params.module());

View File

@@ -1,151 +0,0 @@
use base2k::{
Encoding, Infos, Module, Sampling, SvpPPol, SvpPPolOps, VecZnx, VecZnxDftOps, VecZnxOps,
VmpPMat, VmpPMatOps, is_aligned,
};
use itertools::izip;
use rlwe::ciphertext::{Ciphertext, new_gadget_ciphertext};
use rlwe::elem::ElemCommon;
use rlwe::encryptor::encrypt_rlwe_sk;
use rlwe::keys::SecretKey;
use rlwe::plaintext::Plaintext;
use sampling::source::{Source, new_seed};
fn main() {
let n: usize = 32;
let module: Module = Module::new(n, base2k::BACKEND::FFT64);
let log_base2k: usize = 16;
let log_k: usize = 32;
let cols: usize = 4;
let mut a: VecZnx = module.new_vec_znx(cols);
let mut data: Vec<i64> = vec![0i64; n];
data[0] = 0;
data[1] = 0;
a.encode_vec_i64(log_base2k, log_k, &data, 16);
let mut a_dft: base2k::VecZnxDft = module.new_vec_znx_dft(cols);
module.vec_znx_dft(&mut a_dft, &a, cols);
(0..cols).for_each(|i| {
println!("{:?}", a_dft.at::<f64>(&module, i));
})
}
pub struct GadgetCiphertextProtocol {}
impl GadgetCiphertextProtocol {
pub fn new() -> GadgetCiphertextProtocol {
Self {}
}
pub fn allocate(
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> GadgetCiphertextShare {
GadgetCiphertextShare::new(module, log_base2k, rows, log_q)
}
pub fn gen_share(
module: &Module,
sk: &SecretKey,
pt: &Plaintext,
seed: &[u8; 32],
share: &mut GadgetCiphertextShare,
tmp_bytes: &mut [u8],
) {
share.seed.copy_from_slice(seed);
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(*seed);
let mut sk_ppol: SvpPPol = module.new_svp_ppol();
sk.prepare(module, &mut sk_ppol);
share.value.iter_mut().for_each(|ai| {
//let elem = Elem<VecZnx>{};
//encrypt_rlwe_sk_thread_safe(module, ai, Some(pt.elem()), &sk_ppol, &mut source_xa, &mut source_xe, 3.2, tmp_bytes);
})
}
}
pub struct GadgetCiphertextShare {
pub seed: [u8; 32],
pub log_q: usize,
pub log_base2k: usize,
pub value: Vec<VecZnx>,
}
impl GadgetCiphertextShare {
pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self {
let value: Vec<VecZnx> = Vec::new();
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
(0..rows).for_each(|_| {
let vec_znx: VecZnx = module.new_vec_znx(cols);
});
Self {
seed: [u8::default(); 32],
log_q: log_q,
log_base2k: log_base2k,
value: value,
}
}
pub fn rows(&self) -> usize {
self.value.len()
}
pub fn cols(&self) -> usize {
self.value[0].cols()
}
pub fn aggregate_inplace(&mut self, module: &Module, a: &GadgetCiphertextShare) {
izip!(self.value.iter_mut(), a.value.iter()).for_each(|(bi, ai)| {
module.vec_znx_add_inplace(bi, ai);
})
}
pub fn get(&self, module: &Module, b: &mut Ciphertext<VmpPMat>, tmp_bytes: &mut [u8]) {
assert!(is_aligned(tmp_bytes.as_ptr()));
let rows: usize = b.rows();
let cols: usize = b.cols();
assert!(tmp_bytes.len() >= gadget_ciphertext_share_get_tmp_bytes(module, rows, cols));
assert_eq!(self.value.len(), rows);
assert_eq!(self.value[0].cols(), cols);
let (tmp_bytes_vmp_prepare_row, tmp_bytes_vec_znx) =
tmp_bytes.split_at_mut(module.vmp_prepare_tmp_bytes(rows, cols));
let mut c: VecZnx = VecZnx::from_bytes_borrow(module.n(), cols, tmp_bytes_vec_znx);
let mut source: Source = Source::new(self.seed);
(0..self.value.len()).for_each(|row_i| {
module.vmp_prepare_row(
b.at_mut(0),
self.value[row_i].raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);
module.fill_uniform(self.log_base2k, &mut c, cols, &mut source);
module.vmp_prepare_row(b.at_mut(1), c.raw(), row_i, tmp_bytes_vmp_prepare_row)
})
}
pub fn get_new(&self, module: &Module, tmp_bytes: &mut [u8]) -> Ciphertext<VmpPMat> {
let mut b: Ciphertext<VmpPMat> =
new_gadget_ciphertext(module, self.log_base2k, self.rows(), self.log_q);
self.get(module, &mut b, tmp_bytes);
b
}
}
pub fn gadget_ciphertext_share_get_tmp_bytes(module: &Module, rows: usize, cols: usize) -> usize {
module.vmp_prepare_tmp_bytes(rows, cols) + module.bytes_of_vec_znx(cols)
}
pub struct CircularCiphertextProtocol {}
pub struct CircularGadgetCiphertextProtocol {}

349
rlwe/src/automorphism.rs Normal file
View File

@@ -0,0 +1,349 @@
use crate::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::ElemCommon,
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes},
keys::SecretKey,
parameters::Parameters,
};
use base2k::{
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
VmpPMatOps, assert_alignement,
};
use sampling::source::Source;
use std::{cmp::min, collections::HashMap};
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
pub struct AutomorphismKey {
pub value: Ciphertext<VmpPMat>,
pub p: i64,
}
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
impl Parameters {
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
automorphism_tmp_bytes(
self.module(),
self.log_base2k(),
res_logq,
in_logq,
gct_logq,
)
}
}
impl AutomorphismKey {
pub fn new(
module: &Module,
p: i64,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> Self {
Self::new_many_core(
module,
&vec![p],
sk,
log_base2k,
rows,
log_q,
source_xa,
source_xe,
sigma,
tmp_bytes,
)
.into_iter()
.next()
.unwrap()
}
pub fn new_many(
module: &Module,
p: &Vec<i64>,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> HashMap<i64, AutomorphismKey> {
Self::new_many_core(
module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes,
)
.into_iter()
.zip(p.iter().cloned())
.map(|(key, pi)| (pi, key))
.collect()
}
fn new_many_core(
module: &Module,
p: &Vec<i64>,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> Vec<Self> {
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
let mut keys: Vec<AutomorphismKey> = Vec::new();
p.iter().for_each(|pi| {
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
let p_inv: i64 = module.galois_element_inv(*pi);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
keys.push(Self {
value: value,
p: *pi,
})
});
keys
}
}
pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
}
pub fn automorphism(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
// c[1] = AUTO(b, p)
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
}
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
}
pub fn automorphism_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
// a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
// a[1] = AUTO(b, p)
module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
}
pub fn automorphism_big(
module: &Module,
c: &mut Ciphertext<VecZnxBig>,
a: &Ciphertext<VecZnx>,
b: &AutomorphismKey,
tmp_bytes: &mut [u8],
) {
let cols = std::cmp::min(c.cols(), a.cols());
#[cfg(debug_assertions)]
{
assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
// c[1] = AUTO(b, p)
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
}
#[cfg(test)]
mod test {
use super::{AutomorphismKey, automorphism};
use crate::{
ciphertext::Ciphertext,
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::encrypt_rlwe_sk,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
use sampling::source::{Source, new_seed};
#[test]
fn test_automorphism() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let p: i64 = -5;
let auto_key: AutomorphismKey = AutomorphismKey::new(
module,
p,
&sk,
log_base2k,
gct_rows,
log_qp,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut data: Vec<i64> = vec![0i64; params.n()];
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
let mut pt_auto: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
// ct <- AUTO(ct)
automorphism(
module,
&mut ct_auto,
&ct,
&auto_key,
gct_cols,
&mut tmp_bytes,
);
// pt = dec(auto(ct)) - auto(pt)
decrypt_rlwe(
module,
pt.elem_mut(),
ct_auto.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
// pt.at(0).print(pt.cols(), 16);
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
let var_a_err: f64 = 1f64 / 12f64;
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
assert!(noise_have <= noise_pred + 1.0);
}
}

View File

@@ -74,24 +74,14 @@ pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) ->
Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows) Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
} }
pub fn new_gadget_ciphertext( pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k; let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols); let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
elem.log_q = log_q; elem.log_q = log_q;
Ciphertext(elem) Ciphertext(elem)
} }
pub fn new_rgsw_ciphertext( pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k; let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols); let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
elem.log_q = log_q; elem.log_q = log_q;

View File

@@ -9,6 +9,7 @@ use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZn
use std::cmp::min; use std::cmp::min;
pub struct Decryptor { pub struct Decryptor {
#[warn(dead_code)]
sk: SvpPPol, sk: SvpPPol,
} }
@@ -32,24 +33,12 @@ impl Parameters {
) )
} }
pub fn decrypt_rlwe( pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
&self,
res: &mut Plaintext,
ct: &Ciphertext<VecZnx>,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) {
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
} }
} }
pub fn decrypt_rlwe( pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
module: &Module,
res: &mut Elem<VecZnx>,
a: &Elem<VecZnx>,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) {
let cols: usize = a.cols(); let cols: usize = a.cols();
assert!( assert!(
@@ -59,17 +48,15 @@ pub fn decrypt_rlwe(
decrypt_rlwe_tmp_byte(module, cols) decrypt_rlwe_tmp_byte(module, cols)
); );
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut res_dft: VecZnxDft = let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk) // res_dft <- DFT(ct[1]) * DFT(sk)
module.svp_apply_dft(&mut res_dft, sk, a.at(1), cols); module.svp_apply_dft(&mut res_dft, sk, a.at(1));
// res_big <- ct[1] x sk // res_big <- ct[1] x sk
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
// res_big <- ct[1] x sk + ct[0] // res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
// res <- normalize(ct[1] x sk + ct[0]) // res <- normalize(ct[1] x sk + ct[0])

View File

@@ -1,7 +1,5 @@
use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
use crate::parameters::Parameters;
pub struct Elem<T> { pub struct Elem<T> {
pub value: Vec<T>, pub value: Vec<T>,
pub log_base2k: usize, pub log_base2k: usize,
@@ -10,20 +8,8 @@ pub struct Elem<T> {
} }
pub trait ElemVecZnx { pub trait ElemVecZnx {
fn from_bytes( fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
module: &Module, fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx>;
fn from_bytes_borrow(
module: &Module,
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx>;
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize; fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
fn zero(&mut self); fn zero(&mut self);
} }
@@ -34,13 +20,7 @@ impl ElemVecZnx for Elem<VecZnx> {
module.n() * cols * size * 8 module.n() * cols * size * 8
} }
fn from_bytes( fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
module: &Module,
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx> {
assert!(size > 0); assert!(size > 0);
let n: usize = module.n(); let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
@@ -60,13 +40,7 @@ impl ElemVecZnx for Elem<VecZnx> {
} }
} }
fn from_bytes_borrow( fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
module: &Module,
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx> {
assert!(size > 0); assert!(size > 0);
let n: usize = module.n(); let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));

View File

@@ -5,12 +5,38 @@ use crate::parameters::Parameters;
use crate::plaintext::Plaintext; use crate::plaintext::Plaintext;
use base2k::sampling::Sampling; use base2k::sampling::Sampling;
use base2k::{ use base2k::{
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
VecZnxOps, VmpPMat, VmpPMatOps, VmpPMatOps,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};
impl Parameters {
pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
}
pub fn encrypt_rlwe_sk(
&self,
ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk(
self.module(),
&mut ct.0,
pt.map(|pt| pt.at(0)),
sk,
source_xa,
source_xe,
self.xe(),
tmp_bytes,
)
}
}
pub struct EncryptorSk { pub struct EncryptorSk {
sk: SvpPPol, sk: SvpPPol,
source_xa: Source, source_xa: Source,
@@ -49,12 +75,7 @@ impl EncryptorSk {
self.source_xe = Source::new(seed) self.source_xe = Source::new(seed)
} }
pub fn encrypt_rlwe_sk( pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext<VecZnx>, pt: Option<&Plaintext>) {
&mut self,
params: &Parameters,
ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext>,
) {
assert!( assert!(
self.initialized == true, self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
@@ -86,42 +107,26 @@ impl EncryptorSk {
} }
} }
impl Parameters {
pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
}
pub fn encrypt_rlwe_sk(
&self,
ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk(
self.module(),
&mut ct.0,
pt.map(|pt| &pt.0),
sk,
source_xa,
source_xe,
self.xe(),
tmp_bytes,
)
}
}
pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
+ module.vec_znx_big_normalize_tmp_bytes()
} }
pub fn encrypt_rlwe_sk( pub fn encrypt_rlwe_sk(
module: &Module, module: &Module,
ct: &mut Elem<VecZnx>, ct: &mut Elem<VecZnx>,
pt: Option<&Elem<VecZnx>>, pt: Option<&VecZnx>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes)
}
fn encrypt_rlwe_sk_core<const PT_POS: u8>(
module: &Module,
ct: &mut Elem<VecZnx>,
pt: Option<&VecZnx>,
sk: &SvpPPol, sk: &SvpPPol,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
@@ -146,36 +151,49 @@ pub fn encrypt_rlwe_sk(
// c1 <- Z_{2^prec}[X]/(X^{N}+1) // c1 <- Z_{2^prec}[X]/(X^{N}+1)
module.fill_uniform(log_base2k, c1, cols, source_xa); module.fill_uniform(log_base2k, c1, cols, source_xa);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
// Scratch space for DFT values // Scratch space for DFT values
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft); let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
// Applies buf_dft <- DFT(s) * DFT(c1) // Applies buf_dft <- DFT(s) * DFT(c1)
module.svp_apply_dft(&mut buf_dft, sk, c1, cols); module.svp_apply_dft(&mut buf_dft, sk, c1);
// Alias scratch space // Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
// buf_big = s x c1 // buf_big = s x c1
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, cols); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
match PT_POS {
// c0 <- -s x c1 + m // c0 <- -s x c1 + m
0 => {
let c0: &mut VecZnx = ct.at_mut(0); let c0: &mut VecZnx = ct.at_mut(0);
if let Some(pt) = pt { if let Some(pt) = pt {
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt);
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
} else { } else {
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
module.vec_znx_negate_inplace(c0); module.vec_znx_negate_inplace(c0);
} }
}
// c1 <- c1 + m
1 => {
if let Some(pt) = pt {
module.vec_znx_add_inplace(c1, pt);
c1.normalize(log_base2k, tmp_bytes_normalize);
}
let c0: &mut VecZnx = ct.at_mut(0);
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
module.vec_znx_negate_inplace(c0);
}
_ => panic!("PT_POS must be 1 or 2"),
}
// c0 <- -s x c1 + m + e // c0 <- -s x c1 + m + e
module.add_normal( module.add_normal(
log_base2k, log_base2k,
c0, ct.at_mut(0),
log_q, log_q,
source_xe, source_xe,
sigma, sigma,
@@ -189,12 +207,7 @@ impl Parameters {
} }
} }
pub fn encrypt_grlwe_sk_tmp_bytes( pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k; let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2) Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::bytes_of(module, log_base2k, log_q) + Plaintext::bytes_of(module, log_base2k, log_q)
@@ -212,10 +225,93 @@ pub fn encrypt_grlwe_sk(
sigma: f64, sigma: f64,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
let rows: usize = ct.rows();
let log_q: usize = ct.log_q(); let log_q: usize = ct.log_q();
//let cols: usize = (log_q + ct.log_base2k() - 1) / ct.log_base2k();
let log_base2k: usize = ct.log_base2k(); let log_base2k: usize = ct.log_base2k();
let (left, right) = ct.0.value.split_at_mut(1);
encrypt_grlwe_sk_core::<0>(
module,
log_base2k,
[&mut left[0], &mut right[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
)
}
impl Parameters {
pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
}
pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, cols)
}
pub fn encrypt_rgsw_sk(
module: &Module,
ct: &mut Ciphertext<VmpPMat>,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k();
let (left, right) = ct.0.value.split_at_mut(2);
let (ll, lr) = left.split_at_mut(1);
let (rl, rr) = right.split_at_mut(1);
encrypt_grlwe_sk_core::<0>(
module,
log_base2k,
[&mut ll[0], &mut lr[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
);
encrypt_grlwe_sk_core::<1>(
module,
log_base2k,
[&mut rl[0], &mut rr[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
);
}
fn encrypt_grlwe_sk_core<const PT_POS: u8>(
module: &Module,
log_base2k: usize,
mut ct: [&mut VmpPMat; 2],
log_q: usize,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let rows: usize = ct[0].rows();
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q); let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
@@ -234,20 +330,18 @@ pub fn encrypt_grlwe_sk(
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk); let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
let mut tmp_elem: Elem<VecZnx> = let mut tmp_elem: Elem<VecZnx> = Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem);
Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem); let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
let mut tmp_pt: Plaintext =
Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
(0..rows).for_each(|row_i| { (0..rows).for_each(|row_i| {
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw()); tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
// Encrypts RLWE(m * 2^{-log_base2k*i}) // Encrypts RLWE(m * 2^{-log_base2k*i})
encrypt_rlwe_sk( encrypt_rlwe_sk_core::<PT_POS>(
module, module,
&mut tmp_elem, &mut tmp_elem,
Some(&tmp_pt.0), Some(tmp_pt.at(0)),
sk, sk,
source_xa, source_xa,
source_xe, source_xe,
@@ -255,31 +349,21 @@ pub fn encrypt_grlwe_sk(
tmp_bytes_enc_sk, tmp_bytes_enc_sk,
); );
//tmp_pt.at(0).print(tmp_pt.cols(), 16);
//println!();
// Zeroes the ith-row of tmp_pt // Zeroes the ith-row of tmp_pt
tmp_pt.at_mut(0).at_mut(row_i).fill(0); tmp_pt.at_mut(0).at_mut(row_i).fill(0);
//println!("row:{}/{}", row_i, rows);
//tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n());
//tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n());
//println!();
//println!(">>>");
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a] // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
module.vmp_prepare_row( module.vmp_prepare_row(
&mut ct.at_mut(0), ct[0],
tmp_elem.at(0).raw(), tmp_elem.at(0).raw(),
row_i, row_i,
tmp_bytes_vmp_prepare_row, tmp_bytes_vmp_prepare_row,
); );
module.vmp_prepare_row( module.vmp_prepare_row(
&mut ct.at_mut(1), &mut ct[1],
tmp_elem.at(1).raw(), tmp_elem.at(1).raw(),
row_i, row_i,
tmp_bytes_vmp_prepare_row, tmp_bytes_vmp_prepare_row,
); );
}); });
//println!("DONE");
} }

View File

@@ -1,8 +1,8 @@
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min; use std::cmp::min;
pub fn gadget_product_tmp_bytes( pub fn gadget_product_core_tmp_bytes(
module: &Module, module: &Module,
log_base2k: usize, log_base2k: usize,
res_log_q: usize, res_log_q: usize,
@@ -17,14 +17,8 @@ pub fn gadget_product_tmp_bytes(
} }
impl Parameters { impl Parameters {
pub fn gadget_product_tmp_bytes( pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize {
&self, gadget_product_core_tmp_bytes(
res_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
gadget_product_tmp_bytes(
self.module(), self.module(),
self.log_base2k(), self.log_base2k(),
res_log_q, res_log_q,
@@ -35,54 +29,93 @@ impl Parameters {
} }
} }
/// Evaluates the gadget product res <- a x b.
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols cols.
/// * `a`: a [VecZnx] of a_ncols cols.
/// * `b`: a [Ciphertext<VmpPMat>] as a vector of (-Bs + m * 2^{-k} + E, B)
/// containing b_nrows [VecZnx], each of b_ncols cols.
///
/// # Computation
///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
pub fn gadget_product_core( pub fn gadget_product_core(
module: &Module, module: &Module,
res_dft_0: &mut VecZnxDft, res_dft_0: &mut VecZnxDft,
res_dft_1: &mut VecZnxDft, res_dft_1: &mut VecZnxDft,
a: &VecZnx, a: &VecZnx,
a_cols: usize,
b: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>,
b_cols: usize, b_cols: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
assert!(b_cols <= b.cols()); assert!(b_cols <= b.cols());
module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols)); module.vec_znx_dft(res_dft_1, a);
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes); module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
} }
/* pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); + 2 * module.bytes_of_vec_znx_dft(min(c_cols, a_cols));
module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); }
// res_big <- res[0] + res_big[a*G0] /// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0)); ///
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry); /// # Arguments
///
if OVERWRITE { /// * `module`: backend support for operations mod (X^N + 1).
// res[1] = normalize(res_big[a*G1]) /// * `c`: a [Ciphertext<VecZnxBig>] with cols_c cols.
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry); /// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
} else { /// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
// res[1] = normalize(res_big[a*G1] + res[1]) pub fn gadget_product_big(
module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1)); module: &Module,
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry); c: &mut Ciphertext<VecZnxBig>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// c[i] = IDFT(DFT(a[1]) * b[i])
(0..2).for_each(|i| {
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
})
}
/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a.at(i)), b.at(i)>)
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `c`: a [Ciphertext<VecZnx>] with cols_c cols.
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
pub fn gadget_product(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// c[i] = IDFT(DFT(a[1]) * b[i])
(0..2).for_each(|i| {
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
})
} }
*/
#[cfg(test)] #[cfg(test)]
mod test { mod test {
@@ -97,8 +130,8 @@ mod test {
plaintext::Plaintext, plaintext::Plaintext,
}; };
use base2k::{ use base2k::{
Infos, BACKEND, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, alloc_aligned_u8,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};
@@ -125,7 +158,6 @@ mod test {
// scratch space // scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8( let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.decrypt_rlwe_tmp_byte(params.log_qp()) params.decrypt_rlwe_tmp_byte(params.log_qp())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())
| params.gadget_product_tmp_bytes( | params.gadget_product_tmp_bytes(
params.log_qp(), params.log_qp(),
params.log_qp(), params.log_qp(),
@@ -172,10 +204,6 @@ mod test {
); );
// Intermediate buffers // Intermediate buffers
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
// Input polynopmial, uniformly distributed // Input polynopmial, uniformly distributed
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
@@ -184,8 +212,7 @@ mod test {
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa); .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
// res = g^-1(a) * gct // res = g^-1(a) * gct
let mut elem_res: Elem<VecZnx> = let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
// Ideal output = a * s // Ideal output = a * s
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols()); let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
@@ -193,27 +220,27 @@ mod test {
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols()); let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
// a * sk0 // a * sk0
params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
params params
.module() .module()
.svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols()); .vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes);
params
.module()
.vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols());
params.module().vec_znx_big_normalize(
params.log_base2k(),
&mut a_times_s,
&a_big,
&mut tmp_bytes,
);
// Plaintext for decrypted output of gadget product // Plaintext for decrypted output of gadget product
let mut pt: Plaintext = let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext. // Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
(1..a.cols() + 1).for_each(|a_cols| { (1..a.cols() + 1).for_each(|a_cols| {
let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols);
a_trunc.copy_from(&a);
(1..gadget_ct.cols() + 1).for_each(|b_cols| { (1..gadget_ct.cols() + 1).for_each(|b_cols| {
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
pt.elem_mut().zero(); pt.elem_mut().zero();
elem_res.zero(); elem_res.zero();
@@ -227,8 +254,7 @@ mod test {
params.module(), params.module(),
&mut res_dft_0, &mut res_dft_0,
&mut res_dft_1, &mut res_dft_1,
&a, &a_trunc,
a_cols,
&gadget_ct, &gadget_ct,
b_cols, b_cols,
&mut tmp_bytes, &mut tmp_bytes,
@@ -237,27 +263,21 @@ mod test {
// res_big_0 = IDFT(res_dft_0) // res_big_0 = IDFT(res_dft_0)
params params
.module() .module()
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
// res_big_1 = IDFT(res_dft_1); // res_big_1 = IDFT(res_dft_1);
params params
.module() .module()
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
// res_big_0 = normalize(res_big_0) // res_big_0 = normalize(res_big_0)
params.module().vec_znx_big_normalize( params
log_base2k, .module()
elem_res.at_mut(0), .vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes);
&res_big_0,
&mut tmp_bytes,
);
// res_big_1 = normalize(res_big_1) // res_big_1 = normalize(res_big_1)
params.module().vec_znx_big_normalize( params
log_base2k, .module()
elem_res.at_mut(1), .vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes);
&res_big_1,
&mut tmp_bytes,
);
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
decrypt_rlwe( decrypt_rlwe(
@@ -271,7 +291,7 @@ mod test {
// a * sk0 + e - a*sk0 = e // a * sk0 + e - a*sk0 = e
params params
.module() .module()
.vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s); .vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s);
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
// pt.at(0).print(pt.elem().cols(), 16); // pt.at(0).print(pt.elem().cols(), 16);
@@ -288,28 +308,23 @@ mod test {
let a_logq: usize = a_cols * log_base2k; let a_logq: usize = a_cols * log_base2k;
let b_logq: usize = b_cols * log_base2k; let b_logq: usize = b_cols * log_base2k;
let var_msg: f64 = params.xs() as f64; let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
let noise_pred: f64 = println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq);
params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
assert!(noise_have <= noise_pred + 1.0); let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
println!("noise_pred: {}", noise_have); println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_pred); println!("noise_have: {}", noise_have);
// assert!(noise_have <= noise_pred + 1.0);
}); });
}); });
} }
} }
impl Parameters { impl Parameters {
pub fn noise_grlwe_product( pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 {
&self,
var_msg: f64,
var_a_err: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let n: f64 = self.n() as f64; let n: f64 = self.n() as f64;
let var_xs: f64 = self.xs() as f64; let var_xs: f64 = self.xs() as f64;
@@ -360,9 +375,8 @@ pub fn noise_grlwe_product(
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
let mut noise: f64 = let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
(a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); noise += var_msg * var_a_err * a_scale * a_scale * n;
noise += var_msg * var_a_err * a_scale * a_scale;
noise = noise.sqrt(); noise = noise.sqrt();
noise /= b_scale; noise /= b_scale;
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]

View File

@@ -7,11 +7,7 @@ use sampling::source::Source;
pub struct KeyGenerator {} pub struct KeyGenerator {}
impl KeyGenerator { impl KeyGenerator {
pub fn gen_secret_key_thread_safe( pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey {
&self,
params: &Parameters,
source: &mut Source,
) -> SecretKey {
let mut sk: SecretKey = SecretKey::new(params.module()); let mut sk: SecretKey = SecretKey::new(params.module());
sk.fill_ternary_hw(params.xs(), source); sk.fill_ternary_hw(params.xs(), source);
sk sk
@@ -26,8 +22,7 @@ impl KeyGenerator {
) -> PublicKey { ) -> PublicKey {
let mut xa_source: Source = source.branch(); let mut xa_source: Source = source.branch();
let mut xe_source: Source = source.branch(); let mut xe_source: Source = source.branch();
let mut pk: PublicKey = let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
pk.gen_thread_safe( pk.gen_thread_safe(
params.module(), params.module(),
sk_ppol, sk_ppol,
@@ -40,12 +35,7 @@ impl KeyGenerator {
} }
} }
pub fn gen_switching_key_tmp_bytes( pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> usize {
encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
} }

79
rlwe/src/key_switching.rs Normal file
View File

@@ -0,0 +1,79 @@
use crate::ciphertext::Ciphertext;
use crate::elem::ElemCommon;
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
use std::cmp::min;
pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
+ module.bytes_of_vec_znx_dft(gct_cols);
}
pub fn key_switch_rlwe(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes);
}
pub fn key_switch_rlwe_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes);
}
fn key_switch_rlwe_core(
module: &Module,
c: *mut Ciphertext<VecZnx>,
a: *const Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
// SAFETY WARNING: must ensure `c` and `a` are valid for read/write
let c: &mut Ciphertext<VecZnx> = unsafe { &mut *c };
let a: &Ciphertext<VecZnx> = unsafe { &*a };
let cols: usize = min(min(c.cols(), a.cols()), b.rows());
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.cols());
assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.cols(), a.cols(), b.rows(), b.cols()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut res_big = res_dft.as_vec_znx_big();
module.vec_znx_dft(&mut a1_dft, a.at(1));
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
}
pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}
pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}

View File

@@ -1,10 +1,13 @@
pub mod automorphism;
pub mod ciphertext; pub mod ciphertext;
pub mod decryptor; pub mod decryptor;
pub mod elem; pub mod elem;
pub mod encryptor; pub mod encryptor;
pub mod gadget_product; pub mod gadget_product;
pub mod key_generator; pub mod key_generator;
pub mod key_switching;
pub mod keys; pub mod keys;
pub mod parameters; pub mod parameters;
pub mod plaintext; pub mod plaintext;
pub mod rgsw_product; pub mod rgsw_product;
pub mod trace;

View File

@@ -1,5 +1,7 @@
use base2k::module::{BACKEND, Module}; use base2k::module::{BACKEND, Module};
pub const DEFAULT_SIGMA: f64 = 3.2;
pub struct ParametersLiteral { pub struct ParametersLiteral {
pub backend: BACKEND, pub backend: BACKEND,
pub log_n: usize, pub log_n: usize,

View File

@@ -43,12 +43,7 @@ impl Plaintext {
)) ))
} }
pub fn from_bytes_borrow( pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
module: &Module,
log_base2k: usize,
log_q: usize,
bytes: &mut [u8],
) -> Self {
Self(Elem::<VecZnx>::from_bytes_borrow( Self(Elem::<VecZnx>::from_bytes_borrow(
module, log_base2k, log_q, 1, bytes, module, log_base2k, log_q, 1, bytes,
)) ))

View File

@@ -1,54 +1,300 @@
use crate::{ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
ciphertext::Ciphertext, use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
elem::{Elem, ElemCommon, ElemVecZnx},
};
use base2k::{
Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps,
};
use std::cmp::min; use std::cmp::min;
impl Parameters {
pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
rgsw_product_tmp_bytes(
self.module(),
self.log_base2k(),
res_logq,
in_logq,
gct_logq,
)
}
}
pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
+ 2 * module.bytes_of_vec_znx_dft(gct_cols);
}
pub fn rgsw_product( pub fn rgsw_product(
module: &Module, module: &Module,
_res: &mut Elem<VecZnx>, c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
let _log_base2k: usize = b.log_base2k(); #[cfg(debug_assertions)]
let rows: usize = min(b.rows(), a.cols()); {
let cols: usize = b.cols(); assert!(b_cols <= b.cols());
let in_cols = a.cols(); assert_eq!(c.size(), 2);
let out_cols: usize = a.cols(); assert_eq!(a.size(), 2);
assert_eq!(b.size(), 4);
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
assert_alignement(tmp_bytes.as_ptr());
}
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols); let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
let bytes_of_vmp_apply_dft_to_dft = let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols); let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft);
let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) =
tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft); let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft);
let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft);
let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
// c0_dft <- DFT(a[0]) module.vec_znx_dft(&mut ai_dft, a.at(0));
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols); module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols] module.vec_znx_dft(&mut ai_dft, a.at(1));
module.vmp_apply_dft_to_dft( module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
&mut r1_dft, module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
&c1_dft,
&b.0.value[0], module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
bytes_of_vmp_apply_dft_to_dft, module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes);
}
pub fn rgsw_product_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.cols());
assert_eq!(a.size(), 2);
assert_eq!(b.size(), 4);
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
module.vec_znx_dft(&mut ai_dft, a.at(0));
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
module.vec_znx_dft(&mut ai_dft, a.at(1));
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes);
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes);
}
#[cfg(test)]
mod test {
use crate::{
ciphertext::{Ciphertext, new_rgsw_ciphertext},
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk},
keys::SecretKey,
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
plaintext::Plaintext,
rgsw_product::rgsw_product_inplace,
};
use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned};
use sampling::source::{Source, new_seed};
#[test]
fn test_rgsw_product() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.rgsw_product_tmp_bytes(log_q, log_q, log_qp)
| params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp),
); );
// c1_dft <- DFT(a[1]) // Samplers for public and private randomness
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols); let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let mut ct_rgsw: Ciphertext<VmpPMat> = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp);
let k: i64 = 3;
// X^k
let m: Scalar = module.new_scalar();
let data: &mut [i64] = m.raw_mut();
data[k as usize] = 1;
encrypt_rgsw_sk(
module,
&mut ct_rgsw,
&m,
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
DEFAULT_SIGMA,
&mut tmp_bytes,
);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0));
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes);
decrypt_rlwe(
module,
pt.elem_mut(),
ct.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0));
// pt.at(0).print(pt.cols(), 16);
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
let var_a0_err: f64 = params.xe() * params.xe();
let var_a1_err: f64 = 1f64 / 12f64;
let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q());
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
assert!(noise_have <= noise_pred + 1.0);
}
}
impl Parameters {
pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 {
let n: f64 = self.n() as f64;
let var_xs: f64 = self.xs() as f64;
let var_gct_err_lhs: f64;
let var_gct_err_rhs: f64;
if b_logq < self.log_qp() {
let var_round: f64 = 1f64 / 12f64;
var_gct_err_lhs = var_round;
var_gct_err_rhs = var_round;
} else {
var_gct_err_lhs = self.xe() * self.xe();
var_gct_err_rhs = 0f64;
}
noise_rgsw_product(
n,
self.log_base2k(),
var_xs,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
a_logq,
b_logq,
)
}
}
pub fn noise_rgsw_product(
n: f64,
log_base2k: usize,
var_xs: f64,
var_msg: f64,
var_a0_err: f64,
var_a1_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_logq: usize = min(a_logq, b_logq);
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (log_base2k)) as f64;
let var_base: f64 = base * base / 12f64;
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a0_err * a_scale * a_scale * n;
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
noise = noise.sqrt();
noise /= b_scale;
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
} }

113
rlwe/src/test.rs Normal file
View File

@@ -0,0 +1,113 @@
use base2k::{alloc_aligned, SvpPPol, SvpPPolOps, VecZnx, BACKEND};
use sampling::source::{Source, new_seed};
use crate::{ciphertext::Ciphertext, decryptor::decrypt_rlwe, elem::ElemCommon, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULT_SIGMA}, plaintext::Plaintext};
pub struct Context{
pub params: Parameters,
pub sk0: SecretKey,
pub sk0_ppol:SvpPPol,
pub sk1: SecretKey,
pub sk1_ppol: SvpPPol,
pub tmp_bytes: Vec<u8>,
}
impl Context{
pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self{
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: log_n,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: DEFAULT_SIGMA,
xs: 1 << (log_n-1),
};
let params: Parameters =Parameters::new(&params_lit);
let module = params.module();
let log_q: usize = params.log_q();
let mut source_xs: Source = Source::new(new_seed());
let mut sk0: SecretKey = SecretKey::new(module);
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk0_ppol: base2k::SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk0_ppol, &sk0.0);
let mut sk1: SecretKey = SecretKey::new(module);
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk1_ppol: base2k::SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk1_ppol, &sk1.0);
let tmp_bytes: Vec<u8> = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q)| params.encrypt_rlwe_sk_tmp_bytes(log_q));
Context{
params: params,
sk0: sk0,
sk0_ppol: sk0_ppol,
sk1: sk1,
sk1_ppol: sk1_ppol,
tmp_bytes: tmp_bytes,
}
}
pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext<VecZnx>){
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
encrypt_rlwe_sk(
self.params.module(),
ct.elem_mut(),
Some(pt.elem()),
&self.sk0_ppol,
&mut source_xa,
&mut source_xe,
self.params.xe(),
&mut self.tmp_bytes,
);
}
pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext<VecZnx>, pt: &Plaintext){
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
encrypt_rlwe_sk(
self.params.module(),
ct.elem_mut(),
Some(pt.elem()),
&self.sk1_ppol,
&mut source_xa,
&mut source_xe,
self.params.xe(),
&mut self.tmp_bytes,
);
}
pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
decrypt_rlwe(
self.params.module(),
pt.elem_mut(),
ct.elem(),
&self.sk0_ppol,
&mut self.tmp_bytes,
);
}
pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
decrypt_rlwe(
self.params.module(),
pt.elem_mut(),
ct.elem(),
&self.sk1_ppol,
&mut self.tmp_bytes,
);
}
}

236
rlwe/src/trace.rs Normal file
View File

@@ -0,0 +1,236 @@
use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement};
use std::collections::HashMap;
pub fn trace_galois_elements(module: &Module) -> Vec<i64> {
let mut gal_els: Vec<i64> = Vec::new();
(0..module.log_n()).for_each(|i| {
if i == 0 {
gal_els.push(-1);
} else {
gal_els.push(module.galois_element(1 << (i - 1)));
}
});
gal_els
}
impl Parameters {
pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq)
}
}
pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
}
pub fn trace_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
start: usize,
end: usize,
b: &HashMap<i64, AutomorphismKey>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
let cols: usize = a.cols();
let b_rows: usize;
if let Some((_, key)) = b.iter().next() {
b_rows = key.value.rows();
#[cfg(debug_assertions)]
{
println!("{} {}", b_cols, key.value.cols());
assert!(b_cols <= key.value.cols())
}
} else {
panic!("b: HashMap<i64, AutomorphismKey>, is empty")
}
#[cfg(debug_assertions)]
{
assert!(start <= end);
assert!(end <= module.n());
assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols));
assert_alignement(tmp_bytes.as_ptr());
}
let cols: usize = std::cmp::min(b_cols, a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
let log_base2k: usize = a.log_base2k();
(start..end).for_each(|i| {
a.at_mut(0).rsh(log_base2k, 1, tmp_bytes);
a.at_mut(1).rsh(log_base2k, 1, tmp_bytes);
let p: i64;
if i == 0 {
p = -1;
} else {
p = module.galois_element(1 << (i - 1));
}
if let Some(key) = b.get(&p) {
module.vec_znx_dft(&mut a1_dft, a.at(1));
// a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(<DFT(a[1]), key[0]>)))
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);
// a[1] = NORMALIZE(a[1] + AUTO(IDFT(<DFT(a[1]), key[1]>)))
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1));
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
} else {
panic!("b[{}] is empty", p)
}
})
}
#[cfg(test)]
mod test {
use super::{trace_galois_elements, trace_inplace};
use crate::{
automorphism::AutomorphismKey,
ciphertext::Ciphertext,
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::encrypt_rlwe_sk,
keys::SecretKey,
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned};
use sampling::source::{Source, new_seed};
use std::collections::HashMap;
#[test]
fn test_trace_inplace() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let gal_els: Vec<i64> = trace_galois_elements(module);
let auto_keys: HashMap<i64, AutomorphismKey> = AutomorphismKey::new_many(
module,
&gal_els,
&sk,
log_base2k,
gct_rows,
log_qp,
&mut source_xa,
&mut source_xe,
DEFAULT_SIGMA,
&mut tmp_bytes,
);
let mut data: Vec<i64> = vec![0i64; params.n()];
data.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = 1 + i as i64);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);
pt.at(0).print(pt.cols(), 16);
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes);
trace_inplace(
module,
&mut ct,
4,
module.log_n(),
&auto_keys,
gct_cols,
&mut tmp_bytes,
);
// pt = dec(auto(ct)) - auto(pt)
decrypt_rlwe(
module,
pt.elem_mut(),
ct.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
pt.at(0).print(pt.cols(), 16);
pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data);
println!("trace: {:?}", &data[..16]);
}
}

79
rustfmt.toml Normal file
View File

@@ -0,0 +1,79 @@
max_width = 130
hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
indent_style = "Block"
use_small_heuristics = "Default"
fn_call_width = 60
attr_fn_like_width = 100
struct_lit_width = 18
struct_variant_width = 35
array_width = 60
chain_width = 60
single_line_if_else_max_width = 50
single_line_let_else_max_width = 50
wrap_comments = false
format_code_in_doc_comments = true
doc_comment_code_block_width = 100
comment_width = 80
normalize_comments = true
normalize_doc_attributes = true
format_strings = true
format_macro_matchers = false
format_macro_bodies = true
skip_macro_invocations = []
hex_literal_case = "Preserve"
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
imports_granularity = "Preserve"
group_imports = "Preserve"
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
short_array_element_width_threshold = 10
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
match_arm_leading_pipes = "Never"
force_multiline_blocks = false
fn_params_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2024"
style_edition = "2024"
inline_attribute_width = 0
format_generated_files = true
generated_marker_line_search_limit = 5
merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
required_version = "1.8.0"
unstable_features = true
disable_all_formatting = false
skip_children = false
show_parse_errors = true
error_on_line_overflow = false
error_on_unformatted = false
ignore = []
emit_mode = "Files"
make_backup = false

View File

@@ -1,5 +1,5 @@
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_core::{OsRng, RngCore}; use rand_core::{OsRng, RngCore};
const MAXF64: f64 = 9007199254740992.0; const MAXF64: f64 = 9007199254740992.0;