Various improvements to memory management and the API

[module]: added enum for backend
[VecZnx, VecZnxDft, VecZnxBig, VmpPMat]: added ptr to data
[VecZnxBorrow]: removed
[VecZnxAPI]: removed
This commit is contained in:
Jean-Philippe Bossuat
2025-03-17 12:07:40 +01:00
parent 97a1559bf2
commit 46c577409e
28 changed files with 896 additions and 1064 deletions
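The headline change in the diffs below is that the backend is now selected with a MODULETYPE enum value at runtime instead of a const generic, and scratch buffers go through alloc_aligned. A minimal sketch of the new calling convention, assembled only from signatures appearing in this commit (degree and base-2k values are illustrative):

use base2k::{alloc_aligned, Module, VecZnx, VecZnxOps, MODULETYPE};

fn main() {
    // Backend selection is a runtime enum value, not Module::new::<FFT64>(n).
    let module: Module = Module::new(16, MODULETYPE::FFT64);
    let mut a: VecZnx = module.new_vec_znx(2);
    // Scratch buffers are allocated aligned rather than with vec![0; ...].
    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_normalize_tmp_bytes());
    a.normalize(15, &mut carry);
    module.free(); // the Module itself still wraps a C-side allocation
}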


@@ -56,6 +56,7 @@
"xlocnum": "cpp", "xlocnum": "cpp",
"xloctime": "cpp", "xloctime": "cpp",
"xmemory": "cpp", "xmemory": "cpp",
"xtr1common": "cpp" "xtr1common": "cpp",
"vec_znx_arithmetic_private.h": "c"
} }
} }

Cargo.lock (generated)

@@ -655,6 +655,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"base2k", "base2k",
"criterion", "criterion",
"itertools 0.14.0",
"rand_distr", "rand_distr",
"rug", "rug",
"sampling", "sampling",


@@ -1,6 +1,6 @@
use base2k::{ use base2k::{
Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx,
VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, FFT64, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MODULETYPE,
}; };
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;
@@ -11,9 +11,9 @@ fn main() {
let cols: usize = 3; let cols: usize = 3;
let msg_cols: usize = 2; let msg_cols: usize = 2;
let log_scale: usize = msg_cols * log_base2k - 5; let log_scale: usize = msg_cols * log_base2k - 5;
let module: Module = Module::new::<FFT64>(n); let module: Module = Module::new(n, MODULETYPE::FFT64);
let mut carry: Vec<u8> = vec![0; module.vec_znx_big_normalize_tmp_bytes()]; let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
let seed: [u8; 32] = [0; 32]; let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed); let mut source: Source = Source::new(seed);


@@ -1,13 +1,13 @@
use base2k::{ use base2k::{
Encoding, Free, Infos, Module, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, FFT64, VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, MODULETYPE,
}; };
fn main() { fn main() {
let log_n: i32 = 5; let log_n: i32 = 5;
let n: usize = 1 << log_n; let n: usize = 1 << log_n;
let module: Module = Module::new::<FFT64>(n); let module: Module = Module::new(n, MODULETYPE::FFT64);
let log_base2k: usize = 15; let log_base2k: usize = 15;
let cols: usize = 5; let cols: usize = 5;
let log_k: usize = log_base2k * cols - 5; let log_k: usize = log_base2k * cols - 5;
@@ -19,7 +19,7 @@ fn main() {
let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols)
| module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols); | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
let mut buf: Vec<u8> = vec![0; tmp_bytes]; let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
let mut a_values: Vec<i64> = vec![i64::default(); n]; let mut a_values: Vec<i64> = vec![i64::default(); n];
a_values[1] = (1 << log_base2k) + 1; a_values[1] = (1 << log_base2k) + 1;
@@ -37,7 +37,7 @@ fn main() {
}); });
(0..rows).for_each(|i| { (0..rows).for_each(|i| {
vecznx[i].data[i * n + 1] = 1 as i64; vecznx[i].raw_mut()[i * n + 1] = 1 as i64;
}); });
let slices: Vec<&[i64]> = vecznx.dblptr(); let slices: Vec<&[i64]> = vecznx.dblptr();
@@ -60,8 +60,6 @@ fn main() {
res.print(res.cols(), n); res.print(res.cols(), n);
module.free(); module.free();
c_dft.free();
vmp_pmat.free();
//println!("{:?}", values_res) println!("{:?}", values_res)
} }


@@ -1,5 +1,5 @@
use crate::ffi::znx::znx_zero_i64_ref; use crate::ffi::znx::znx_zero_i64_ref;
use crate::{VecZnx, VecZnxBorrow, VecZnxCommon}; use crate::{Infos, VecZnx};
use itertools::izip; use itertools::izip;
use rug::{Assign, Float}; use rug::{Assign, Float};
use std::cmp::min; use std::cmp::min;
@@ -89,42 +89,7 @@ impl Encoding for VecZnx {
} }
} }
impl Encoding for VecZnxBorrow { fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
encode_vec_i64(self, log_base2k, log_k, data, log_max)
}
fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) {
decode_vec_i64(self, log_base2k, log_k, data)
}
fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) {
decode_vec_float(self, log_base2k, data)
}
fn encode_coeff_i64(
&mut self,
log_base2k: usize,
log_k: usize,
i: usize,
value: i64,
log_max: usize,
) {
encode_coeff_i64(self, log_base2k, log_k, i, value, log_max)
}
fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 {
decode_coeff_i64(self, log_base2k, log_k, i)
}
}
fn encode_vec_i64<T: VecZnxCommon>(
a: &mut T,
log_base2k: usize,
log_k: usize,
data: &[i64],
log_max: usize,
) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
debug_assert!( debug_assert!(
@@ -170,7 +135,7 @@ fn encode_vec_i64<T: VecZnxCommon>(
} }
} }
fn decode_vec_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) { fn decode_vec_i64(a: &VecZnx, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
debug_assert!( debug_assert!(
data.len() >= a.n(), data.len() >= a.n(),
@@ -194,7 +159,7 @@ fn decode_vec_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, data:
}) })
} }
fn decode_vec_float<T: VecZnxCommon>(a: &T, log_base2k: usize, data: &mut [Float]) { fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) {
let cols: usize = a.cols(); let cols: usize = a.cols();
debug_assert!( debug_assert!(
data.len() >= a.n(), data.len() >= a.n(),
@@ -224,8 +189,8 @@ fn decode_vec_float<T: VecZnxCommon>(a: &T, log_base2k: usize, data: &mut [Float
}); });
} }
fn encode_coeff_i64<T: VecZnxCommon>( fn encode_coeff_i64(
a: &mut T, a: &mut VecZnx,
log_base2k: usize, log_base2k: usize,
log_k: usize, log_k: usize,
i: usize, i: usize,
@@ -268,7 +233,7 @@ fn encode_coeff_i64<T: VecZnxCommon>(
} }
} }
fn decode_coeff_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 { fn decode_coeff_i64(a: &VecZnx, log_base2k: usize, log_k: usize, i: usize) -> i64 {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
debug_assert!(i < a.n()); debug_assert!(i < a.n());
let data: &[i64] = a.raw(); let data: &[i64] = a.raw();


@@ -1,43 +0,0 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft;
use crate::ffi::vmp;
use crate::{SvpPPol, VecZnxBig, VecZnxDft, VmpPMat};
/// This trait should be implemented by structs that point to
/// memory allocated through C.
pub trait Free {
// Frees the memory and self destructs.
fn free(self);
}
impl Free for VmpPMat {
/// Frees the C allocated memory of the [VmpPMat] and self destructs the struct.
fn free(self) {
unsafe { vmp::delete_vmp_pmat(self.data) };
drop(self);
}
}
impl Free for VecZnxDft {
fn free(self) {
unsafe { vec_znx_dft::delete_vec_znx_dft(self.0) };
drop(self);
}
}
impl Free for VecZnxBig {
fn free(self) {
unsafe {
vec_znx_big::delete_vec_znx_big(self.0);
}
drop(self);
}
}
impl Free for SvpPPol {
fn free(self) {
unsafe { svp::delete_svp_ppol(self.0) };
let _ = drop(self);
}
}
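With free.rs deleted, the structs that used to implement Free now own their memory through a Rust Vec (see the new_vec_znx_dft and new_vec_znx_big constructors later in this commit), so plain Drop releases it; only Module keeps an explicit free() for its C-side allocation. A small sketch of the resulting lifetime handling, assuming the signatures shown in this diff:

use base2k::{Module, VecZnxDft, VecZnxDftOps, MODULETYPE};

fn main() {
    let module: Module = Module::new(8, MODULETYPE::FFT64);
    {
        let _c_dft: VecZnxDft = module.new_vec_znx_dft(3);
        // no c_dft.free() anymore: the backing Vec<u8> is dropped here
    }
    module.free(); // Module remains the one explicit C-side deallocation
}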


@@ -1,5 +1,3 @@
use crate::{VecZnx, VecZnxBorrow, VmpPMat};
pub trait Infos { pub trait Infos {
/// Returns the ring degree of the receiver. /// Returns the ring degree of the receiver.
fn n(&self) -> usize; fn n(&self) -> usize;
@@ -14,71 +12,3 @@ pub trait Infos {
/// Returns the number of rows of the receiver. /// Returns the number of rows of the receiver.
fn rows(&self) -> usize; fn rows(&self) -> usize;
} }
impl Infos for VecZnx {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.data.len() / self.n
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
impl Infos for VecZnxBorrow {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
impl Infos for VmpPMat {
/// Returns the ring dimension of the [VmpPMat].
fn n(&self) -> usize {
self.n
}
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat]
fn rows(&self) -> usize {
self.rows
}
/// Returns the number of cols of the [VmpPMat].
/// The number of cols refers to the number of cols
/// of each [VecZnxDft].
/// This method is equivalent to [Self::cols].
fn cols(&self) -> usize {
self.cols
}
}


@@ -8,7 +8,6 @@ pub mod encoding;
)] )]
// Other modules and exports // Other modules and exports
pub mod ffi; pub mod ffi;
pub mod free;
pub mod infos; pub mod infos;
pub mod module; pub mod module;
pub mod sampling; pub mod sampling;
@@ -20,7 +19,6 @@ pub mod vec_znx_dft;
pub mod vmp; pub mod vmp;
pub use encoding::*; pub use encoding::*;
pub use free::*;
pub use infos::*; pub use infos::*;
pub use module::*; pub use module::*;
pub use sampling::*; pub use sampling::*;
@@ -124,11 +122,3 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
pub fn alloc_aligned<T>(size: usize) -> Vec<T> { pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(size, DEFAULTALIGN) alloc_aligned_custom::<T>(size, DEFAULTALIGN)
} }
fn alias_mut_slice_to_vec<T>(slice: &[T]) -> Vec<T> {
unsafe {
let ptr: *mut T = slice.as_ptr() as *mut T;
let len: usize = slice.len();
Vec::from_raw_parts(ptr, len, len)
}
}


@@ -1,26 +1,46 @@
use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE}; use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE};
use crate::{Free, GALOISGENERATOR}; use crate::GALOISGENERATOR;
pub type MODULETYPE = u8; #[derive(Copy, Clone)]
pub const FFT64: u8 = 0; #[repr(u8)]
pub const NTT120: u8 = 1; pub enum MODULETYPE {
FFT64,
NTT120,
}
pub struct Module(pub *mut MODULE, pub usize); pub struct Module {
pub ptr: *mut MODULE,
pub n: usize,
pub backend: MODULETYPE,
}
impl Module { impl Module {
// Instantiates a new module. // Instantiates a new module.
pub fn new<const MODULETYPE: MODULETYPE>(n: usize) -> Self { pub fn new(n: usize, module_type: MODULETYPE) -> Self {
unsafe { unsafe {
let m: *mut module_info_t = new_module_info(n as u64, MODULETYPE as u32); let module_type_u32: u32;
match module_type {
MODULETYPE::FFT64 => module_type_u32 = 0,
MODULETYPE::NTT120 => module_type_u32 = 1,
}
let m: *mut module_info_t = new_module_info(n as u64, module_type_u32);
if m.is_null() { if m.is_null() {
panic!("Failed to create module."); panic!("Failed to create module.");
} }
Self(m, n) Self {
ptr: m,
n: n,
backend: module_type,
}
} }
} }
pub fn backend(&self) -> MODULETYPE {
self.backend
}
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
self.1 self.n
} }
pub fn log_n(&self) -> usize { pub fn log_n(&self) -> usize {
@@ -53,11 +73,9 @@ impl Module {
(gal_el as i64) * gen.signum() (gal_el as i64) * gen.signum()
} }
}
impl Free for Module { pub fn free(self) {
fn free(self) { unsafe { delete_module_info(self.ptr) }
unsafe { delete_module_info(self.0) }
drop(self); drop(self);
} }
} }
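Since MODULETYPE is now a #[repr(u8)] enum deriving Copy and Clone, callers can branch on the backend instead of comparing magic u8 constants. A short usage sketch based on the methods above:

use base2k::{Module, MODULETYPE};

fn main() {
    let module: Module = Module::new(32, MODULETYPE::NTT120);
    // backend() returns the enum by value (it is Copy), so it can be matched on.
    match module.backend() {
        MODULETYPE::FFT64 => println!("fft64 backend"),
        MODULETYPE::NTT120 => println!("ntt120 backend"),
    }
    module.free();
}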


@@ -1,16 +1,16 @@
use crate::{Infos, Module, VecZnxApi}; use crate::{Infos, Module, VecZnx};
use rand_distr::{Distribution, Normal}; use rand_distr::{Distribution, Normal};
use sampling::source::Source; use sampling::source::Source;
pub trait Sampling<T: VecZnxApi + Infos> { pub trait Sampling {
/// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] /// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut T, cols: usize, source: &mut Source); fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>( fn add_dist_f64<D: Distribution<f64>>(
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut T, a: &mut VecZnx,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
dist: D, dist: D,
@@ -21,7 +21,7 @@ pub trait Sampling<T: VecZnxApi + Infos> {
fn add_normal( fn add_normal(
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut T, a: &mut VecZnx,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
sigma: f64, sigma: f64,
@@ -29,8 +29,8 @@ pub trait Sampling<T: VecZnxApi + Infos> {
); );
} }
impl<T: VecZnxApi + Infos> Sampling<T> for Module { impl Sampling for Module {
fn fill_uniform(&self, log_base2k: usize, a: &mut T, cols: usize, source: &mut Source) { fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k; let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1; let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64; let base2k_half: i64 = (base2k >> 1) as i64;
@@ -43,7 +43,7 @@ impl<T: VecZnxApi + Infos> Sampling<T> for Module {
fn add_dist_f64<D: Distribution<f64>>( fn add_dist_f64<D: Distribution<f64>>(
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut T, a: &mut VecZnx,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
dist: D, dist: D,
@@ -79,7 +79,7 @@ impl<T: VecZnxApi + Infos> Sampling<T> for Module {
fn add_normal( fn add_normal(
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut T, a: &mut VecZnx,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
sigma: f64, sigma: f64,


@@ -1,13 +1,18 @@
use crate::ffi::svp; use crate::ffi::svp;
use crate::{alias_mut_slice_to_vec, assert_alignement, Module, VecZnxApi, VecZnxDft}; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{assert_alignement, Module, VecZnx, VecZnxDft};
use crate::{alloc_aligned, cast, Infos}; use crate::{alloc_aligned, cast_mut, Infos};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand_core::RngCore; use rand_core::RngCore;
use rand_distr::{Distribution, WeightedIndex}; use rand_distr::{Distribution, WeightedIndex};
use sampling::source::Source; use sampling::source::Source;
pub struct Scalar(pub Vec<i64>); pub struct Scalar {
pub n: usize,
pub data: Vec<i64>,
pub ptr: *mut i64,
}
impl Module { impl Module {
pub fn new_scalar(&self) -> Scalar { pub fn new_scalar(&self) -> Scalar {
@@ -17,52 +22,70 @@ impl Module {
impl Scalar { impl Scalar {
pub fn new(n: usize) -> Self { pub fn new(n: usize) -> Self {
Self(alloc_aligned::<i64>(n)) let mut data: Vec<i64> = alloc_aligned::<i64>(n);
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
data: data,
ptr: ptr,
}
} }
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
self.0.len() self.n
} }
pub fn buffer_size(n: usize) -> usize { pub fn buffer_size(n: usize) -> usize {
n n
} }
pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) { pub fn from_buffer(&mut self, n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::buffer_size(n); let size: usize = Self::buffer_size(n);
debug_assert!( debug_assert!(
buf.len() >= size, bytes.len() == size,
"invalid buffer: buf.len()={} < self.buffer_size(n={})={}", "invalid buffer: bytes.len()={} < self.buffer_size(n={})={}",
buf.len(), bytes.len(),
n, n,
size size
); );
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(buf.as_ptr()) assert_alignement(bytes.as_ptr())
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes_i64.len(), bytes_i64.len()),
ptr: ptr,
}
} }
self.0 = alias_mut_slice_to_vec(cast::<u8, i64>(&buf[..size]))
} }
pub fn as_ptr(&self) -> *const i64 { pub fn as_ptr(&self) -> *const i64 {
self.0.as_ptr() self.ptr
}
pub fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
} }
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1]; let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap(); let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.0 self.data
.iter_mut() .iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]); .for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
} }
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
assert!(hw <= self.n()); assert!(hw <= self.n());
self.0[..hw] self.data[..hw]
.iter_mut() .iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.0.shuffle(source); self.data.shuffle(source);
} }
} }
@@ -105,35 +128,23 @@ pub trait SvpPPolOps {
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol]. /// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft<T: VecZnxApi + Infos>( fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize);
&self,
c: &mut VecZnxDft,
a: &SvpPPol,
b: &T,
b_cols: usize,
);
} }
impl SvpPPolOps for Module { impl SvpPPolOps for Module {
fn new_svp_ppol(&self) -> SvpPPol { fn new_svp_ppol(&self) -> SvpPPol {
unsafe { SvpPPol(svp::new_svp_ppol(self.0), self.n()) } unsafe { SvpPPol(svp::new_svp_ppol(self.ptr), self.n()) }
} }
fn bytes_of_svp_ppol(&self) -> usize { fn bytes_of_svp_ppol(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.0) as usize } unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
} }
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } unsafe { svp::svp_prepare(self.ptr, svp_ppol.0, a.as_ptr()) }
} }
fn svp_apply_dft<T: VecZnxApi + Infos>( fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize) {
&self,
c: &mut VecZnxDft,
a: &SvpPPol,
b: &T,
b_cols: usize,
) {
debug_assert!( debug_assert!(
c.cols() >= b_cols, c.cols() >= b_cols,
"invalid c_vector: c_vector.cols()={} < b.cols()={}", "invalid c_vector: c_vector.cols()={} < b.cols()={}",
@@ -142,8 +153,8 @@ impl SvpPPolOps for Module {
); );
unsafe { unsafe {
svp::svp_apply_dft( svp::svp_apply_dft(
self.0, self.ptr,
c.0, c.ptr as *mut vec_znx_dft_t,
b_cols as u64, b_cols as u64,
a.0, a.0,
b.as_ptr(), b.as_ptr(),
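svp_apply_dft loses its T: VecZnxApi + Infos generic and now takes &VecZnx directly, and Scalar exposes its degree through an explicit n field. A minimal end-to-end sketch of the de-generified path, using only signatures from this diff (the probability and sizes are illustrative):

use base2k::{
    Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
    MODULETYPE,
};
use sampling::source::Source;

fn main() {
    let module: Module = Module::new(16, MODULETYPE::FFT64);
    let mut source: Source = Source::new([0u8; 32]);
    let mut s: Scalar = module.new_scalar();
    s.fill_ternary_prob(0.5, &mut source); // ternary coefficients, Pr[+-1] = 0.5
    let mut s_ppol: SvpPPol = module.new_svp_ppol();
    module.svp_prepare(&mut s_ppol, &s);
    let a: VecZnx = module.new_vec_znx(3);
    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(3);
    // Concrete &VecZnx argument instead of the old generic T.
    module.svp_apply_dft(&mut c_dft, &s_ppol, &a, a.cols());
    module.free();
}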


@@ -1,18 +1,35 @@
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::ffi::znx; use crate::ffi::znx;
use crate::ffi::znx::znx_zero_i64_ref; use crate::{alloc_aligned, assert_alignement};
use crate::{alias_mut_slice_to_vec, alloc_aligned, assert_alignement};
use crate::{Infos, Module}; use crate::{Infos, Module};
use itertools::izip; use itertools::izip;
use std::cmp::min; use std::cmp::min;
/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
/// in memory.
#[derive(Clone)]
pub struct VecZnx {
/// Polynomial degree.
n: usize,
/// Number of columns.
cols: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
data: Vec<i64>,
/// Pointer to data (data can be empty if [VecZnx] borrows space instead of owning it).
ptr: *mut i64,
}
pub trait VecZnxVec { pub trait VecZnxVec {
fn dblptr(&self) -> Vec<&[i64]>; fn dblptr(&self) -> Vec<&[i64]>;
fn dblptr_mut(&mut self) -> Vec<&mut [i64]>; fn dblptr_mut(&mut self) -> Vec<&mut [i64]>;
} }
impl<T: VecZnxCommon> VecZnxVec for Vec<T> { impl VecZnxVec for Vec<VecZnx> {
fn dblptr(&self) -> Vec<&[i64]> { fn dblptr(&self) -> Vec<&[i64]> {
self.iter().map(|v| v.raw()).collect() self.iter().map(|v| v.raw()).collect()
} }
@@ -22,328 +39,141 @@ impl<T: VecZnxCommon> VecZnxVec for Vec<T> {
} }
} }
pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
type Owned: VecZnxCommon;
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned;
/// Returns the minimum size of the [u8] array required to assign a
/// new backend array.
fn bytes_of(n: usize, cols: usize) -> usize;
/// Copy the data of a onto self.
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where
Self: AsMut<B>;
/// Returns the backing array.
fn raw(&self) -> &[i64];
/// Returns the mutable backing array.
fn raw_mut(&mut self) -> &mut [i64];
/// Returns a non-mutable pointer to the backing array.
fn as_ptr(&self) -> *const i64;
/// Returns a mutable pointer to the backing array.
fn as_mut_ptr(&mut self) -> *mut i64;
/// Returns a non-mutable reference to the i-th cols.
fn at(&self, i: usize) -> &[i64];
/// Returns a mutable reference to the i-th cols .
fn at_mut(&mut self, i: usize) -> &mut [i64];
/// Returns a non-mutable pointer to the i-th cols.
fn at_ptr(&self, i: usize) -> *const i64;
/// Returns a mutable pointer to the i-th cols.
fn at_mut_ptr(&mut self, i: usize) -> *mut i64;
/// Zeroes the backing array.
fn zero(&mut self);
/// Normalization: propagates carry and ensures each coefficients
/// falls into the range [-2^{K-1}, 2^{K-1}].
fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]);
/// Right shifts the coefficients by k bits.
///
/// # Arguments
///
/// * `log_base2k`: the base two logarithm of the coefficients decomposition.
/// * `k`: the shift amount.
/// * `carry`: scratch space of size at least equal to self.n() * self.cols() << 3.
///
/// # Panics
///
/// The method will panic if carry.len() < self.n() * self.cols() << 3.
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]);
/// If self.n() > a.n(): Extracts X^{i*self.n()/a.n()} -> X^{i}.
/// If self.n() < a.n(): Extracts X^{i} -> X^{i*a.n()/self.n()}.
///
/// # Arguments
///
/// * `a`: the receiver polynomial in which the extracted coefficients are stored.
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where
Self: AsRef<B>;
fn print(&self, cols: usize, n: usize);
}
pub fn bytes_of_vec_znx(n: usize, cols: usize) -> usize { pub fn bytes_of_vec_znx(n: usize, cols: usize) -> usize {
n * cols * 8 n * cols * 8
} }
pub struct VecZnxBorrow { impl VecZnx {
pub n: usize,
pub cols: usize,
pub data: *mut i64,
}
impl AsMut<VecZnxBorrow> for VecZnxBorrow {
fn as_mut(&mut self) -> &mut VecZnxBorrow {
self
}
}
impl AsRef<VecZnxBorrow> for VecZnxBorrow {
fn as_ref(&self) -> &VecZnxBorrow {
self
}
}
impl VecZnxCommon for VecZnxBorrow {}
impl VecZnxApi for VecZnxBorrow {
type Owned = VecZnxBorrow;
/// Returns a new struct implementing [VecZnxBorrow] with the provided data as backing array.
///
/// The struct will *NOT* take ownership of buf[..[VecZnx::bytes_of]]
///
/// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [VecZnx::bytes_of].
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
let size = Self::bytes_of(n, cols);
debug_assert!(
bytes.len() >= size,
"invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}",
bytes.len(),
n,
cols,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
VecZnxBorrow {
n: n,
cols: cols,
data: cast_mut(&mut bytes[..size]).as_mut_ptr(),
}
}
fn bytes_of(n: usize, cols: usize) -> usize {
bytes_of_vec_znx(n, cols)
}
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where
Self: AsMut<B>,
{
copy_vec_znx_from::<A, B>(self.as_mut(), a);
}
fn as_ptr(&self) -> *const i64 {
self.data
}
fn as_mut_ptr(&mut self) -> *mut i64 {
self.data
}
fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.data, self.n * self.cols) }
}
fn raw_mut(&mut self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.data, self.n * self.cols) }
}
fn at(&self, i: usize) -> &[i64] {
let n: usize = self.n();
&self.raw()[n * i..n * (i + 1)]
}
fn at_mut(&mut self, i: usize) -> &mut [i64] {
let n: usize = self.n();
&mut self.raw_mut()[n * i..n * (i + 1)]
}
fn at_ptr(&self, i: usize) -> *const i64 {
self.data.wrapping_add(self.n * i)
}
fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
self.data.wrapping_add(self.n * i)
}
fn zero(&mut self) {
unsafe {
znx_zero_i64_ref((self.n * self.cols) as u64, self.data);
}
}
fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
}
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
}
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where
Self: AsRef<B>,
{
switch_degree(a, self.as_ref());
}
fn print(&self, cols: usize, n: usize) {
(0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n]))
}
}
impl VecZnxCommon for VecZnx {}
impl VecZnxApi for VecZnx {
type Owned = VecZnx;
/// Returns a new struct implementing [VecZnx] with the provided data as backing array. /// Returns a new struct implementing [VecZnx] with the provided data as backing array.
/// ///
/// The struct will take ownership of buf[..[VecZnx::bytes_of]] /// The struct will take ownership of buf[..[VecZnx::bytes_of]]
/// ///
/// User must ensure that data is properly aligned and that /// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [VecZnx::bytes_of]. /// the size of data is at least equal to [VecZnx::bytes_of].
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned { pub fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
let size = Self::bytes_of(n, cols);
debug_assert!(
bytes.len() >= size,
"invalid bytes: bytes.len()={} < self.bytes_of(n={}, cols={})={}",
bytes.len(),
n,
cols,
size
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(bytes.as_ptr()) assert_eq!(bytes.len(), Self::bytes_of(n, cols));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
VecZnx {
n: n,
cols: cols,
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes_i64.len(), bytes_i64.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(bytes.len() >= Self::bytes_of(n, cols));
assert_alignement(bytes.as_ptr());
} }
VecZnx { VecZnx {
n: n, n: n,
data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])), cols: cols,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
} }
} }
fn bytes_of(n: usize, cols: usize) -> usize { pub fn bytes_of(n: usize, cols: usize) -> usize {
bytes_of_vec_znx(n, cols) bytes_of_vec_znx(n, cols)
} }
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A) pub fn copy_from(&mut self, a: &VecZnx) {
where copy_vec_znx_from(self, a);
Self: AsMut<B>,
{
copy_vec_znx_from(self.as_mut(), a);
} }
fn raw(&self) -> &[i64] { pub fn raw(&self) -> &[i64] {
&self.data unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.cols) }
} }
fn raw_mut(&mut self) -> &mut [i64] { pub fn borrowing(&self) -> bool {
&mut self.data self.data.len() == 0
} }
fn as_ptr(&self) -> *const i64 { pub fn raw_mut(&mut self) -> &mut [i64] {
self.data.as_ptr() unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.cols) }
} }
fn as_mut_ptr(&mut self) -> *mut i64 { pub fn as_ptr(&self) -> *const i64 {
self.data.as_mut_ptr() self.ptr
} }
fn at(&self, i: usize) -> &[i64] { pub fn as_mut_ptr(&mut self) -> *mut i64 {
self.ptr
}
pub fn at(&self, i: usize) -> &[i64] {
let n: usize = self.n(); let n: usize = self.n();
&self.raw()[n * i..n * (i + 1)] &self.raw()[n * i..n * (i + 1)]
} }
fn at_mut(&mut self, i: usize) -> &mut [i64] { pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
let n: usize = self.n(); let n: usize = self.n();
&mut self.raw_mut()[n * i..n * (i + 1)] &mut self.raw_mut()[n * i..n * (i + 1)]
} }
fn at_ptr(&self, i: usize) -> *const i64 { pub fn at_ptr(&self, i: usize) -> *const i64 {
&self.data[i * self.n] as *const i64 self.ptr.wrapping_add(i * self.n)
} }
fn at_mut_ptr(&mut self, i: usize) -> *mut i64 { pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
&mut self.data[i * self.n] as *mut i64 self.ptr.wrapping_add(i * self.n)
} }
fn zero(&mut self) { pub fn zero(&mut self) {
unsafe { znx::znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) } unsafe { znx::znx_zero_i64_ref((self.n * self.cols) as u64, self.ptr) }
} }
fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry) normalize(log_base2k, self, carry)
} }
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry) rsh(log_base2k, self, k, carry)
} }
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A) pub fn switch_degree(&self, a: &mut VecZnx) {
where switch_degree(a, self)
Self: AsRef<B>,
{
switch_degree(a, self.as_ref())
} }
fn print(&self, cols: usize, n: usize) { pub fn print(&self, cols: usize, n: usize) {
(0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n])) (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n]))
} }
} }
/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients. impl Infos for VecZnx {
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array /// Returns the base 2 logarithm of the [VecZnx] degree.
/// in the memory. fn log_n(&self) -> usize {
#[derive(Clone)] (usize::BITS - (self.n - 1).leading_zeros()) as _
pub struct VecZnx {
/// Polynomial degree.
pub n: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<i64>,
} }
impl AsMut<VecZnx> for VecZnx { /// Returns the [VecZnx] degree.
fn as_mut(&mut self) -> &mut VecZnx { fn n(&self) -> usize {
self self.n
}
} }
impl AsRef<VecZnx> for VecZnx { /// Returns the number of cols of the [VecZnx].
fn as_ref(&self) -> &VecZnx { fn cols(&self) -> usize {
self self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
} }
} }
/// Copies the coefficients of `a` on the receiver. /// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays. /// Copy is done with the minimum size matching both backing arrays.
pub fn copy_vec_znx_from<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) { pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
let data_a: &[i64] = a.raw(); let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut(); let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len()); let size = min(data_b.len(), data_a.len());
@@ -353,9 +183,13 @@ pub fn copy_vec_znx_from<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
impl VecZnx { impl VecZnx {
/// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\]. /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\].
pub fn new(n: usize, cols: usize) -> Self { pub fn new(n: usize, cols: usize) -> Self {
let mut data: Vec<i64> = alloc_aligned::<i64>(n * cols);
let ptr: *mut i64 = data.as_mut_ptr();
Self { Self {
n: n, n: n,
data: alloc_aligned::<i64>(n * cols), cols: cols,
data: data,
ptr: ptr,
} }
} }
@@ -370,8 +204,12 @@ impl VecZnx {
return; return;
} }
if !self.borrowing() {
self.data self.data
.truncate((self.cols() - k / log_base2k) * self.n()); .truncate((self.cols() - k / log_base2k) * self.n());
}
self.cols -= k / log_base2k;
let k_rem: usize = k % log_base2k; let k_rem: usize = k % log_base2k;
@@ -384,7 +222,7 @@ impl VecZnx {
} }
} }
pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) { pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) {
let (n_in, n_out) = (a.n(), b.n()); let (n_in, n_out) = (a.n(), b.n());
let (gap_in, gap_out): (usize, usize); let (gap_in, gap_out): (usize, usize);
@@ -406,7 +244,7 @@ pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
}); });
} }
fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]) { fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
debug_assert!( debug_assert!(
@@ -437,7 +275,7 @@ fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]
} }
} }
pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &mut [u8]) { pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
debug_assert!( debug_assert!(
@@ -469,7 +307,6 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
} }
let mask: i64 = (1 << k_rem) - 1;
let log_base2k: usize = log_base2k; let log_base2k: usize = log_base2k;
(cols_steps..cols).for_each(|i| { (cols_steps..cols).for_each(|i| {
@@ -487,8 +324,6 @@ fn get_base_k_carry(x: i64, k: usize) -> i64{
(x << 64 - k) >> (64 - k) (x << 64 - k) >> (64 - k)
} }
pub trait VecZnxCommon: VecZnxApi + Infos {}
pub trait VecZnxOps { pub trait VecZnxOps {
/// Allocates a new [VecZnx]. /// Allocates a new [VecZnx].
/// ///
@@ -504,50 +339,34 @@ pub trait VecZnxOps {
fn vec_znx_normalize_tmp_bytes(&self) -> usize; fn vec_znx_normalize_tmp_bytes(&self) -> usize;
/// c <- a + b. /// c <- a + b.
fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b + a. /// b <- b + a.
fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A); fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// c <- a - b. /// c <- a - b.
fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b - a. /// b <- b - a.
fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A); fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -a. /// b <- -a.
fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A); fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -b. /// b <- -b.
fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A); fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
/// b <- a * X^k (mod X^{n} + 1) /// b <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A); fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- a * X^k (mod X^{n} + 1) /// a <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A); fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>( fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize);
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize); fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize);
/// Splits b into subrings and copies them into a. /// Splits b into subrings and copies them into a.
/// ///
@@ -555,12 +374,7 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of b have the same ring degree /// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n() /// and that b.n() * b.len() <= a.n()
fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
);
/// Merges the subrings a into b. /// Merges the subrings a into b.
/// ///
@@ -568,7 +382,7 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of a have the same ring degree /// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n() /// and that a.n() * a.len() <= b.n()
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>); fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
} }
impl VecZnxOps for Module { impl VecZnxOps for Module {
@@ -581,19 +395,14 @@ impl VecZnxOps for Module {
} }
fn vec_znx_normalize_tmp_bytes(&self) -> usize { fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize } unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
} }
// c <- a + b // c <- a + b
fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx::vec_znx_add(
self.0, self.ptr,
c.as_mut_ptr(), c.as_mut_ptr(),
c.cols() as u64, c.cols() as u64,
c.n() as u64, c.n() as u64,
@@ -608,10 +417,10 @@ impl VecZnxOps for Module {
} }
// b <- a + b // b <- a + b
fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) { fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx::vec_znx_add(
self.0, self.ptr,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
@@ -626,15 +435,10 @@ impl VecZnxOps for Module {
} }
// c <- a - b // c <- a - b
fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx::vec_znx_sub(
self.0, self.ptr,
c.as_mut_ptr(), c.as_mut_ptr(),
c.cols() as u64, c.cols() as u64,
c.n() as u64, c.n() as u64,
@@ -649,10 +453,10 @@ impl VecZnxOps for Module {
} }
// b <- b - a // b <- b - a
fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) { fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx::vec_znx_sub(
self.0, self.ptr,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
@@ -666,10 +470,10 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) { fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_negate( vec_znx::vec_znx_negate(
self.0, self.ptr,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
@@ -680,10 +484,10 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A) { fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_negate( vec_znx::vec_znx_negate(
self.0, self.ptr,
a.as_mut_ptr(), a.as_mut_ptr(),
a.cols() as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
@@ -694,10 +498,10 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A) { fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_rotate( vec_znx::vec_znx_rotate(
self.0, self.ptr,
k, k,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
@@ -709,10 +513,10 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A) { fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
unsafe { unsafe {
vec_znx::vec_znx_rotate( vec_znx::vec_znx_rotate(
self.0, self.ptr,
k, k,
a.as_mut_ptr(), a.as_mut_ptr(),
a.cols() as u64, a.cols() as u64,
@@ -739,11 +543,11 @@ impl VecZnxOps for Module {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps}; /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
/// use itertools::izip; /// use itertools::izip;
/// ///
/// let n: usize = 8; // polynomial degree /// let n: usize = 8; // polynomial degree
/// let module = Module::new::<FFT64>(n); /// let module = Module::new(n, MODULETYPE::FFT64);
/// let mut a: VecZnx = module.new_vec_znx(2); /// let mut a: VecZnx = module.new_vec_znx(2);
/// let mut b: VecZnx = module.new_vec_znx(2); /// let mut b: VecZnx = module.new_vec_znx(2);
/// let mut c: VecZnx = module.new_vec_znx(2); /// let mut c: VecZnx = module.new_vec_znx(2);
@@ -759,21 +563,15 @@ impl VecZnxOps for Module {
/// (1..col.len()).for_each(|i|{ /// (1..col.len()).for_each(|i|{
/// col[n-i] = -(i as i64) /// col[n-i] = -(i as i64)
/// }); /// });
/// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// izip!(b.raw().iter(), c.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ``` /// ```
fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>( fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) {
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
) {
debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
debug_assert_eq!(b.n(), self.n()); debug_assert_eq!(b.n(), self.n());
debug_assert!(a.cols() >= a_cols); debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.0, self.ptr,
k, k,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
@@ -799,11 +597,11 @@ impl VecZnxOps for Module {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps}; /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
/// use itertools::izip; /// use itertools::izip;
/// ///
/// let n: usize = 8; // polynomial degree /// let n: usize = 8; // polynomial degree
/// let module = Module::new::<FFT64>(n); /// let module = Module::new(n, MODULETYPE::FFT64);
/// let mut a: VecZnx = VecZnx::new(n, 2); /// let mut a: VecZnx = VecZnx::new(n, 2);
/// let mut b: VecZnx = VecZnx::new(n, 2); /// let mut b: VecZnx = VecZnx::new(n, 2);
/// ///
@@ -818,14 +616,14 @@ impl VecZnxOps for Module {
/// (1..col.len()).for_each(|i|{ /// (1..col.len()).for_each(|i|{
/// col[n-i] = -(i as i64) /// col[n-i] = -(i as i64)
/// }); /// });
/// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// izip!(a.raw().iter(), b.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ``` /// ```
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) { fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) {
debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
debug_assert!(a.cols() >= a_cols); debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.0, self.ptr,
k, k,
a.as_mut_ptr(), a.as_mut_ptr(),
a.cols() as u64, a.cols() as u64,
@@ -837,12 +635,7 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>( fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
) {
let (n_in, n_out) = (a.n(), b[0].n()); let (n_in, n_out) = (a.n(), b[0].n());
debug_assert!( debug_assert!(
@@ -868,7 +661,7 @@ impl VecZnxOps for Module {
}) })
} }
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) { fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
let (n_in, n_out) = (b.n(), a[0].n()); let (n_in, n_out) = (b.n(), a[0].n());
debug_assert!( debug_assert!(
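VecZnx now distinguishes owning from borrowing at construction: from_bytes takes ownership of the caller's buffer, while from_bytes_borrow leaves the internal data Vec empty and only records the pointer, which is what borrowing() reports. A sketch of the borrowing path, assuming the constructors shown above (sizes are illustrative):

use base2k::{alloc_aligned, VecZnx};

fn main() {
    let (n, cols): (usize, usize) = (16, 3);
    // Scratch space sized with VecZnx::bytes_of and allocated aligned.
    let mut bytes: Vec<u8> = alloc_aligned(VecZnx::bytes_of(n, cols));
    let mut a: VecZnx = VecZnx::from_bytes_borrow(n, cols, &mut bytes);
    assert!(a.borrowing()); // internal Vec is empty; `bytes` still owns the memory
    a.zero();
    assert_eq!(a.raw().len(), n * cols);
}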


@@ -1,29 +1,69 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big::{self, vec_znx_bigcoeff_t};
use crate::ffi::vec_znx_dft; use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); pub struct VecZnxBig {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub backend: MODULETYPE,
}
impl VecZnxBig { impl VecZnxBig {
/// Returns a new [VecZnxBig] with the provided data as backing array. /// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly aligned and that /// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
pub fn from_bytes(cols: usize, bytes: &mut [u8]) -> VecZnxBig { pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr())
}; };
VecZnxBig( unsafe {
bytes.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t, Self {
cols, data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
) ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
backend: module.backend,
}
}
}
pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
backend: module.backend,
}
} }
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1) VecZnxDft {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
cols: self.cols,
backend: self.backend,
} }
}
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize { pub fn cols(&self) -> usize {
self.1 self.cols
}
pub fn backend(&self) -> MODULETYPE {
self.backend
} }
} }
@@ -47,39 +87,34 @@ pub trait VecZnxBigOps {
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize; fn bytes_of_vec_znx_big(&self, cols: usize) -> usize;
/// b <- b - a /// b <- b - a
fn vec_znx_big_sub_small_a_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T); fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
/// c <- b - a /// c <- b - a
fn vec_znx_big_sub_small_a<T: VecZnxApi + Infos>( fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
&self,
c: &mut VecZnxBig,
a: &T,
b: &VecZnxBig,
);
/// c <- b + a /// c <- b + a
fn vec_znx_big_add_small<T: VecZnxApi + Infos>(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig); fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
/// b <- b + a /// b <- b + a
fn vec_znx_big_add_small_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T); fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
/// b <- normalize(a) /// b <- normalize(a)
fn vec_znx_big_normalize<T: VecZnxApi + Infos>( fn vec_znx_big_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
b: &mut T, b: &mut VecZnx,
a: &VecZnxBig, a: &VecZnxBig,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
fn vec_znx_big_range_normalize_base2k<T: VecZnxApi + Infos>( fn vec_znx_big_range_normalize_base2k(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut T, res: &mut VecZnx,
a: &VecZnxBig, a: &VecZnxBig,
a_range_begin: usize, a_range_begin: usize,
a_range_xend: usize, a_range_xend: usize,
@@ -94,7 +129,15 @@ pub trait VecZnxBigOps {
impl VecZnxBigOps for Module { impl VecZnxBigOps for Module {
fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig { fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig {
unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, cols as u64), cols) } let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_big(cols));
let ptr: *mut u8 = data.as_mut_ptr();
VecZnxBig {
data: data,
ptr: ptr,
n: self.n(),
cols: cols,
backend: self.backend(),
}
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
@@ -108,55 +151,50 @@ impl VecZnxBigOps for Module {
{ {
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr())
} }
VecZnxBig::from_bytes(cols, bytes) VecZnxBig::from_bytes(self, cols, bytes)
} }
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize { fn bytes_of_vec_znx_big(&self, cols: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.0, cols as u64) as usize } unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize }
} }
fn vec_znx_big_sub_small_a_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T) { fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_sub_small_a( vec_znx_big::vec_znx_big_sub_small_a(
self.0, self.ptr,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
a.as_ptr(), a.as_ptr(),
a.cols() as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
) )
} }
} }
fn vec_znx_big_sub_small_a<T: VecZnxApi + Infos>( fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
&self,
c: &mut VecZnxBig,
a: &T,
b: &VecZnxBig,
) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_sub_small_a( vec_znx_big::vec_znx_big_sub_small_a(
self.0, self.ptr,
c.0, c.ptr as *mut vec_znx_bigcoeff_t,
c.cols() as u64, c.cols() as u64,
a.as_ptr(), a.as_ptr(),
a.cols() as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
) )
} }
} }
fn vec_znx_big_add_small<T: VecZnxApi + Infos>(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig) { fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_add_small( vec_znx_big::vec_znx_big_add_small(
self.0, self.ptr,
c.0, c.ptr as *mut vec_znx_bigcoeff_t,
c.cols() as u64, c.cols() as u64,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
a.as_ptr(), a.as_ptr(),
a.cols() as u64, a.cols() as u64,
@@ -165,13 +203,13 @@ impl VecZnxBigOps for Module {
} }
} }
fn vec_znx_big_add_small_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T) { fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_add_small( vec_znx_big::vec_znx_big_add_small(
self.0, self.ptr,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
a.as_ptr(), a.as_ptr(),
a.cols() as u64, a.cols() as u64,
@@ -181,13 +219,13 @@ impl VecZnxBigOps for Module {
} }
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize } unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
} }
fn vec_znx_big_normalize<T: VecZnxApi + Infos>( fn vec_znx_big_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
b: &mut T, b: &mut VecZnx,
a: &VecZnxBig, a: &VecZnxBig,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
@@ -203,12 +241,12 @@ impl VecZnxBigOps for Module {
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_normalize_base2k( vec_znx_big::vec_znx_big_normalize_base2k(
self.0, self.ptr,
log_base2k as u64, log_base2k as u64,
b.as_mut_ptr(), b.as_mut_ptr(),
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
a.0, a.ptr as *mut vec_znx_bigcoeff_t,
a.cols() as u64, a.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
@@ -216,13 +254,13 @@ impl VecZnxBigOps for Module {
} }
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize } unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
} }
fn vec_znx_big_range_normalize_base2k<T: VecZnxApi + Infos>( fn vec_znx_big_range_normalize_base2k(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut T, res: &mut VecZnx,
a: &VecZnxBig, a: &VecZnxBig,
a_range_begin: usize, a_range_begin: usize,
a_range_xend: usize, a_range_xend: usize,
@@ -241,12 +279,12 @@ impl VecZnxBigOps for Module {
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_range_normalize_base2k( vec_znx_big::vec_znx_big_range_normalize_base2k(
self.0, self.ptr,
log_base2k as u64, log_base2k as u64,
res.as_mut_ptr(), res.as_mut_ptr(),
res.cols() as u64, res.cols() as u64,
res.n() as u64, res.n() as u64,
a.0, a.ptr as *mut vec_znx_bigcoeff_t,
a_range_begin as u64, a_range_begin as u64,
a_range_xend as u64, a_range_xend as u64,
a_range_step as u64, a_range_step as u64,
@@ -258,11 +296,11 @@ impl VecZnxBigOps for Module {
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_automorphism( vec_znx_big::vec_znx_big_automorphism(
self.0, self.ptr,
gal_el, gal_el,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64, b.cols() as u64,
a.0, a.ptr as *mut vec_znx_bigcoeff_t,
a.cols() as u64, a.cols() as u64,
); );
} }
@@ -271,11 +309,11 @@ impl VecZnxBigOps for Module {
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) {
unsafe { unsafe {
vec_znx_big::vec_znx_big_automorphism( vec_znx_big::vec_znx_big_automorphism(
self.0, self.ptr,
gal_el, gal_el,
a.0, a.ptr as *mut vec_znx_bigcoeff_t,
a.cols() as u64, a.cols() as u64,
a.0, a.ptr as *mut vec_znx_bigcoeff_t,
a.cols() as u64, a.cols() as u64,
); );
} }
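as_vec_znx_dft (and its dual as_vec_znx_big) now builds the alias by copying the pointer and metadata while leaving the alias's data Vec empty, so the underlying buffer keeps a single owner. A small sketch under the signatures above:

use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, MODULETYPE};

fn main() {
    let module: Module = Module::new(16, MODULETYPE::FFT64);
    let mut big: VecZnxBig = module.new_vec_znx_big(2);
    // The alias shares big's buffer: its own data Vec is empty, only ptr is set.
    let alias: VecZnxDft = big.as_vec_znx_dft();
    assert_eq!(alias.n(), big.n());
    assert_eq!(alias.cols(), big.cols());
    module.free(); // big, the owner, is dropped at end of scope
}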


@@ -1,33 +1,104 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxBig}; use crate::{alloc_aligned, VecZnx};
use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); pub struct VecZnxDft {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub backend: MODULETYPE,
}
impl VecZnxDft { impl VecZnxDft {
/// Returns a new [VecZnxDft] with the provided data as backing array. /// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly aligned and that /// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()) assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols));
assert_alignement(bytes.as_ptr())
}
unsafe {
VecZnxDft {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols,
backend: module.backend,
}
}
}
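/// Returns a new [VecZnxDft] that borrows the provided bytes as backing array
/// without taking ownership of them.
/// A minimal sketch, assuming an aligned buffer of the exact size reported by
/// [Module::bytes_of_vec_znx_dft]:
/// ```
/// use base2k::{Module, VecZnxDft, VecZnxDftOps, MODULETYPE, alloc_aligned};
///
/// let module: Module = Module::new(32, MODULETYPE::FFT64);
/// let cols: usize = 4;
/// let mut bytes: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(cols));
/// let a_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(&module, cols, &mut bytes);
/// ```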
pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols));
assert_alignement(bytes.as_ptr());
}
VecZnxDft {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols,
backend: module.backend,
} }
VecZnxDft(
tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t,
cols,
)
} }
/// Cast a [VecZnxDft] into a [VecZnxBig]. /// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array /// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft]. /// with the original [VecZnxDft].
pub fn as_vec_znx_big(&mut self) -> VecZnxBig { pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1) VecZnxBig {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
cols: self.cols,
backend: self.backend,
} }
}
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize { pub fn cols(&self) -> usize {
self.1 self.cols
}
pub fn backend(&self) -> MODULETYPE {
self.backend
}
/// Returns a non-mutable slice of `T` over the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw<T>(&self, module: &Module) -> &[T] {
let ptr: *const T = self.ptr as *const T;
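// The backing buffer holds cols * n coefficients of 8 bytes each; the length is re-expressed in units of T.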
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
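/// Returns a non-mutable slice of `T` over the `col_i`-th column of the [VecZnxDft].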
pub fn at<T>(&self, module: &Module, col_i: usize) -> &[T] {
&self.raw::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
}
/// Returns a mutable slice of `T` over the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw_mut<T>(&mut self, module: &Module) -> &mut [T] {
let ptr: *mut T = self.ptr as *mut T;
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
pub fn at_mut<T>(&mut self, module: &Module, col_i: usize) -> &mut [T] {
&mut self.raw_mut::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
} }
} }
@@ -72,12 +143,20 @@ pub trait VecZnxDftOps {
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_limbs: usize); fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
} }
impl VecZnxDftOps for Module { impl VecZnxDftOps for Module {
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft { fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) } let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_dft(cols));
let ptr: *mut u8 = data.as_mut_ptr();
VecZnxDft {
data,
ptr,
n: self.n(),
cols,
backend: self.backend(),
}
} }
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
@@ -91,11 +170,11 @@ impl VecZnxDftOps for Module {
{ {
assert_alignement(tmp_bytes.as_ptr()) assert_alignement(tmp_bytes.as_ptr())
} }
VecZnxDft::from_bytes(cols, tmp_bytes) VecZnxDft::from_bytes(self, cols, tmp_bytes)
} }
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize { fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.0, cols as u64) as usize } unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
} }
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
@@ -106,19 +185,25 @@ impl VecZnxDftOps for Module {
a_limbs a_limbs
); );
unsafe { unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.cols() as u64, a.0, a_limbs as u64) vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a_limbs as u64,
)
} }
} }
fn vec_znx_idft_tmp_bytes(&self) -> usize { fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize } unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
} }
/// b <- DFT(a) /// b <- DFT(a)
/// ///
/// # Panics /// # Panics
/// If b.cols < a_cols /// If b.cols < a_cols
fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) { fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) {
debug_assert!( debug_assert!(
b.cols() >= a_cols, b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}", "invalid a_cols: b.cols()={} < a_cols={}",
@@ -127,8 +212,8 @@ impl VecZnxDftOps for Module {
); );
unsafe { unsafe {
vec_znx_dft::vec_znx_dft( vec_znx_dft::vec_znx_dft(
self.0, self.ptr,
b.0, b.ptr as *mut vec_znx_dft_t,
b.cols() as u64, b.cols() as u64,
a.as_ptr(), a.as_ptr(),
a_cols as u64, a_cols as u64,
@@ -169,10 +254,10 @@ impl VecZnxDftOps for Module {
} }
unsafe { unsafe {
vec_znx_dft::vec_znx_idft( vec_znx_dft::vec_znx_idft(
self.0, self.ptr,
b.0, b.ptr as *mut vec_znx_bigcoeff_t,
a.cols() as u64, a.cols() as u64,
a.0, a.ptr as *mut vec_znx_dft_t,
a_cols as u64, a_cols as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )

View File

@@ -1,5 +1,6 @@
use crate::ffi::vmp; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft}; use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array. /// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -11,20 +12,75 @@ use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
/// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat]. /// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat].
/// See the trait [VmpPMatOps] for additional information. /// See the trait [VmpPMatOps] for additional information.
pub struct VmpPMat { pub struct VmpPMat {
/// The pointer to the C memory. /// Raw data; empty when borrowing scratch space.
pub data: *mut vmp::vmp_pmat_t, data: Vec<u8>,
/// Pointer to data. Can point to scratch space.
ptr: *mut u8,
/// The number of [VecZnxDft]. /// The number of [VecZnxDft].
pub rows: usize, rows: usize,
/// The number of cols in each [VecZnxDft]. /// The number of cols in each [VecZnxDft].
pub cols: usize, cols: usize,
/// The ring degree of each [VecZnxDft]. /// The ring degree of each [VecZnxDft].
pub n: usize, n: usize,
backend: MODULETYPE,
}
impl Infos for VmpPMat {
/// Returns the ring dimension of the [VmpPMat].
fn n(&self) -> usize {
self.n
}
fn log_n(&self) -> usize {
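// Equals ceil(log2(n)); exact log2 when n is a power of two.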
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat]
fn rows(&self) -> usize {
self.rows
}
/// Returns the number of cols of the [VmpPMat].
/// The number of cols refers to the number of cols
/// of each [VecZnxDft].
fn cols(&self) -> usize {
self.cols
}
} }
impl VmpPMat { impl VmpPMat {
/// Returns the pointer to the [vmp_pmat_t]. pub fn as_ptr(&self) -> *const u8 {
pub fn data(&self) -> *mut vmp::vmp_pmat_t { self.ptr
self.data }
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
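/// Returns true if the [VmpPMat] borrows its backing memory (e.g. scratch space)
/// instead of owning it.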
pub fn borrowed(&self) -> bool {
self.data.is_empty()
}
/// Returns a non-mutable slice of `T` over the entire contiguous array of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is rows * cols * n.
pub fn raw<T>(&self) -> &[T] {
let ptr: *const T = self.ptr as *const T;
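// The backing buffer holds rows * cols * n coefficients of 8 bytes each; the length is re-expressed in units of T.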
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
/// Returns a mutable slice of `T` over the entire contiguous array of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is rows * cols * n.
pub fn raw_mut<T>(&mut self) -> &mut [T] {
let ptr: *mut T = self.ptr as *mut T;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
} }
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. /// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
@@ -36,16 +92,16 @@ impl VmpPMat {
/// * `row`: row index (i). /// * `row`: row index (i).
/// * `col`: col index (j). /// * `col`: col index (j).
pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> { pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
let mut res: Vec<T> = vec![T::default(); self.n]; let mut res: Vec<T> = alloc_aligned(self.n);
if self.n < 8 { if self.n < 8 {
res.copy_from_slice( res.copy_from_slice(
&self.get_backend_array::<T>()[(row + col * self.rows()) * self.n() &self.raw::<T>()[(row + col * self.rows()) * self.n()
..(row + col * self.rows()) * (self.n() + 1)], ..(row + col * self.rows() + 1) * self.n()],
); );
} else { } else {
(0..self.n >> 3).for_each(|blk| { (0..self.n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_array(row, col, blk)[..8]); res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
}); });
} }
@@ -54,33 +110,25 @@ impl VmpPMat {
/// When using [`crate::FFT64`] as backend, `T` should be [f64]. /// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64]. /// When using [`crate::NTT120`] as backend, `T` should be [i64].
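/// Returns the 8-coefficient block `blk` of the entry at (row, col). The layout
/// (assumed from the indexing below) interleaves columns in pairs, with a
/// trailing unpaired column when the number of columns is odd.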
fn get_array<T>(&self, row: usize, col: usize, blk: usize) -> &[T] { fn at_block<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
let nrows: usize = self.rows(); let nrows: usize = self.rows();
let ncols: usize = self.cols(); let ncols: usize = self.cols();
if col == (ncols - 1) && (ncols & 1 == 1) { if col == (ncols - 1) && (ncols & 1 == 1) {
&self.get_backend_array::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] &self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
} else { } else {
&self.get_backend_array::<T>()[blk * nrows * ncols * 8 &self.raw::<T>()[blk * nrows * ncols * 8
+ (col / 2) * (2 * nrows) * 8 + (col / 2) * (2 * nrows) * 8
+ row * 2 * 8 + row * 2 * 8
+ (col % 2) * 8..] + (col % 2) * 8..]
} }
} }
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is rows * cols * n.
pub fn get_backend_array<T>(&self) -> &[T] {
let ptr: *const T = self.data as *const T;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
unsafe { &std::slice::from_raw_parts(ptr, len) }
}
} }
/// This trait implements methods for vector matrix product, /// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [VmpPMat]. /// that is, multiplying a [VecZnx] with a [VmpPMat].
pub trait VmpPMatOps { pub trait VmpPMatOps {
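/// Returns the number of bytes required to back a [VmpPMat] with the given
/// number of rows and cols.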
fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize;
/// Allocates a new [VmpPMat] with the given number of rows and columns. /// Allocates a new [VmpPMat] with the given number of rows and columns.
/// ///
/// # Arguments /// # Arguments
@@ -106,26 +154,6 @@ pub trait VmpPMatOps {
/// * `b`: [VmpPMat] on which the values are encoded. /// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// # Example
/// ```
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
/// let module = Module::new::<FFT64>(n);
/// let rows = 5;
/// let cols = 6;
///
/// let mut b_mat: Vec<i64> = vec![0i64;n * cols * rows];
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
///
/// vmp_pmat.free() // don't forget to free the memory once vmp_pmat is not needed anymore.
/// ```
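/// # Example
/// A minimal sketch under the Vec-backed API (no explicit free is needed anymore):
/// ```
/// use base2k::{Module, VmpPMat, VmpPMatOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// // Flattened rows x cols matrix of degree-n polynomials.
/// let b_mat: Vec<i64> = vec![0i64; n * rows * cols];
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
/// ```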
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]);
/// Prepares a [VmpPMat] from a vector of [VecZnx]. /// Prepares a [VmpPMat] from a vector of [VecZnx].
@@ -137,32 +165,6 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// ///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// # Example
/// ```
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
/// let module: Module = Module::new::<FFT64>(n);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let mut vecznx: Vec<VecZnx>= Vec::new();
/// (0..rows).for_each(|_|{
/// vecznx.push(module.new_vec_znx(cols));
/// });
///
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
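/// # Example
/// A minimal sketch under the Vec-backed API, building the rows from [VecZnx] buffers:
/// ```
/// use base2k::{Module, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx: Vec<VecZnx> = (0..rows).map(|_| module.new_vec_znx(cols)).collect();
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.raw()).collect();
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
/// ```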
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]); fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx]. /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
@@ -175,26 +177,6 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// ///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// /// # Example
/// ```
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
/// let module: Module = Module::new::<FFT64>(n);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx = module.new_vec_znx(cols);
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
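/// A minimal sketch under the Vec-backed API, preparing row 0 from a [VecZnx]:
/// ```
/// use base2k::{Module, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx: VecZnx = module.new_vec_znx(cols);
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
/// ```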
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
@@ -237,38 +219,7 @@ pub trait VmpPMatOps {
/// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
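/// # Example
/// A minimal sketch under the Vec-backed API:
/// ```
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols));
/// let vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
/// ```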
/// fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, alloc_aligned};
///
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let cols: usize = 5;
///
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
///
/// c_dft.free();
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_apply_dft<T: VecZnxApi + Infos>(
&self,
c: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
buf: &mut [u8],
);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
/// ///
@@ -311,32 +262,6 @@ pub trait VmpPMatOps {
/// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
///
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let cols: usize = 5;
///
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut tmp_bytes);
///
/// a_dft.free();
/// c_dft.free();
/// vmp_pmat.free();
/// module.free();
/// ```
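/// # Example
/// A minimal sketch under the Vec-backed API:
/// ```
/// use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
/// let vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
/// ```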
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
@@ -363,46 +288,29 @@ pub trait VmpPMatOps {
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `a`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// # Example
/// ```rust
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps,alloc_aligned};
///
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let cols: usize = 5;
///
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut tmp_bytes);
///
/// c_dft.free();
/// vmp_pmat.free();
/// module.free();
/// ```
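/// # Example
/// A minimal sketch under the Vec-backed API:
/// ```
/// use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, MODULETYPE, alloc_aligned};
///
/// let n: usize = 1024;
/// let module: Module = Module::new(n, MODULETYPE::FFT64);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
/// let vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let mut b_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft_inplace(&mut b_dft, &vmp_pmat, &mut buf);
/// ```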
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]);
} }
impl VmpPMatOps for Module { impl VmpPMatOps for Module {
fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize {
unsafe { unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize }
VmpPMat {
data: vmp::new_vmp_pmat(self.0, rows as u64, cols as u64),
rows,
cols,
n: self.n(),
} }
fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vmp_pmat(rows, cols));
let ptr: *mut u8 = data.as_mut_ptr();
VmpPMat {
data,
ptr,
n: self.n(),
cols,
rows,
backend: self.backend(),
} }
} }
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize { fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
unsafe { vmp::vmp_prepare_tmp_bytes(self.0, rows as u64, cols as u64) as usize } unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize }
} }
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
@@ -414,8 +322,8 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_prepare_contiguous( vmp::vmp_prepare_contiguous(
self.0, self.ptr,
b.data(), b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(), a.as_ptr(),
b.rows() as u64, b.rows() as u64,
b.cols() as u64, b.cols() as u64,
@@ -437,8 +345,8 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_prepare_dblptr( vmp::vmp_prepare_dblptr(
self.0, self.ptr,
b.data(), b.as_mut_ptr() as *mut vmp_pmat_t,
ptrs.as_ptr(), ptrs.as_ptr(),
b.rows() as u64, b.rows() as u64,
b.cols() as u64, b.cols() as u64,
@@ -456,8 +364,8 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_prepare_row( vmp::vmp_prepare_row(
self.0, self.ptr,
b.data(), b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(), a.as_ptr(),
row_i as u64, row_i as u64,
b.rows() as u64, b.rows() as u64,
@@ -476,7 +384,7 @@ impl VmpPMatOps for Module {
) -> usize { ) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_tmp_bytes( vmp::vmp_apply_dft_tmp_bytes(
self.0, self.ptr,
res_cols as u64, res_cols as u64,
a_cols as u64, a_cols as u64,
gct_rows as u64, gct_rows as u64,
@@ -485,13 +393,7 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_apply_dft<T: VecZnxApi + Infos>( fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
&self,
c: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
tmp_bytes: &mut [u8],
) {
debug_assert!( debug_assert!(
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()) tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
); );
@@ -501,13 +403,13 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_apply_dft( vmp::vmp_apply_dft(
self.0, self.ptr,
c.0, c.ptr as *mut vec_znx_dft_t,
c.cols() as u64, c.cols() as u64,
a.as_ptr(), a.as_ptr(),
a.cols() as u64, a.cols() as u64,
a.n() as u64, a.n() as u64,
b.data(), b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.cols() as u64, b.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
@@ -524,7 +426,7 @@ impl VmpPMatOps for Module {
) -> usize { ) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes( vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.0, self.ptr,
res_cols as u64, res_cols as u64,
a_cols as u64, a_cols as u64,
gct_rows as u64, gct_rows as u64,
@@ -550,12 +452,12 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.0, self.ptr,
c.0, c.ptr as *mut vec_znx_dft_t,
c.cols() as u64, c.cols() as u64,
a.0, a.ptr as *const vec_znx_dft_t,
a.cols() as u64, a.cols() as u64,
b.data(), b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.cols() as u64, b.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
@@ -574,12 +476,12 @@ impl VmpPMatOps for Module {
} }
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.0, self.ptr,
b.0, b.ptr as *mut vec_znx_dft_t,
b.cols() as u64, b.cols() as u64,
b.0, b.ptr as *mut vec_znx_dft_t,
b.cols() as u64, b.cols() as u64,
a.data(), a.as_ptr() as *const vmp_pmat_t,
a.rows() as u64, a.rows() as u64,
a.cols() as u64, a.cols() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),

View File

@@ -11,6 +11,7 @@ criterion = {workspace = true}
base2k = {path="../base2k"} base2k = {path="../base2k"}
sampling = {path="../sampling"} sampling = {path="../sampling"}
rand_distr = {workspace = true} rand_distr = {workspace = true}
itertools = {workspace = true}
[[bench]] [[bench]]
name = "gadget_product" name = "gadget_product"

View File

@@ -1,5 +1,5 @@
use base2k::{ use base2k::{
FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, Infos, MODULETYPE, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
VmpPMat, alloc_aligned_u8, VmpPMat, alloc_aligned_u8,
}; };
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
@@ -36,6 +36,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
for log_n in 10..11 { for log_n in 10..11 {
let params_lit: ParametersLiteral = ParametersLiteral { let params_lit: ParametersLiteral = ParametersLiteral {
backend: MODULETYPE::FFT64,
log_n: log_n, log_n: log_n,
log_q: 32, log_q: 32,
log_p: 0, log_p: 0,
@@ -45,7 +46,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
xs: 128, xs: 128,
}; };
let params: Parameters = Parameters::new::<FFT64>(&params_lit); let params: Parameters = Parameters::new(&params_lit);
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8( let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
@@ -101,7 +102,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q()); let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
params.encrypt_rlwe_sk_thread_safe( params.encrypt_rlwe_sk(
&mut ct, &mut ct,
None, None,
&sk0_svp_ppol, &sk0_svp_ppol,

View File

@@ -1,4 +1,4 @@
use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi}; use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned};
use rlwe::{ use rlwe::{
ciphertext::Ciphertext, ciphertext::Ciphertext,
elem::ElemCommon, elem::ElemCommon,
@@ -10,6 +10,7 @@ use sampling::source::Source;
fn main() { fn main() {
let params_lit: ParametersLiteral = ParametersLiteral { let params_lit: ParametersLiteral = ParametersLiteral {
backend: base2k::MODULETYPE::FFT64,
log_n: 10, log_n: 10,
log_q: 54, log_q: 54,
log_p: 0, log_p: 0,
@@ -19,13 +20,12 @@ fn main() {
xs: 128, xs: 128,
}; };
let params: Parameters = Parameters::new::<FFT64>(&params_lit); let params: Parameters = Parameters::new(&params_lit);
let mut tmp_bytes: Vec<u8> = vec![ let mut tmp_bytes: Vec<u8> = alloc_aligned(
0u8;
params.decrypt_rlwe_tmp_byte(params.log_q()) params.decrypt_rlwe_tmp_byte(params.log_q())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()),
]; );
let mut source: Source = Source::new([0; 32]); let mut source: Source = Source::new([0; 32]);
let mut sk: SecretKey = SecretKey::new(params.module()); let mut sk: SecretKey = SecretKey::new(params.module());
@@ -35,7 +35,7 @@ fn main() {
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let mut pt: Plaintext<VecZnx> = params.new_plaintext(params.log_q()); let mut pt: Plaintext = params.new_plaintext(params.log_q());
let log_base2k = pt.log_base2k(); let log_base2k = pt.log_base2k();
@@ -56,7 +56,7 @@ fn main() {
let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
params.encrypt_rlwe_sk_thread_safe( params.encrypt_rlwe_sk(
&mut ct, &mut ct,
Some(&pt), Some(&pt),
&sk_svp_ppol, &sk_svp_ppol,
@@ -66,7 +66,6 @@ fn main() {
); );
params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
pt.0.value[0].print(pt.cols(), 16); pt.0.value[0].print(pt.cols(), 16);
let mut have = vec![i64::default(); params.n()]; let mut have = vec![i64::default(); params.n()];

View File

@@ -0,0 +1,151 @@
use base2k::{
Encoding, Infos, Module, Sampling, SvpPPol, SvpPPolOps, VecZnx, VecZnxDftOps, VecZnxOps,
VmpPMat, VmpPMatOps, is_aligned,
};
use itertools::izip;
use rlwe::ciphertext::{Ciphertext, new_gadget_ciphertext};
use rlwe::elem::ElemCommon;
use rlwe::encryptor::encrypt_rlwe_sk;
use rlwe::keys::SecretKey;
use rlwe::plaintext::Plaintext;
use sampling::source::{Source, new_seed};
fn main() {
let n: usize = 32;
let module: Module = Module::new(n, base2k::MODULETYPE::FFT64);
let log_base2k: usize = 16;
let log_k: usize = 32;
let cols: usize = 4;
let mut a: VecZnx = module.new_vec_znx(cols);
let mut data: Vec<i64> = vec![0i64; n];
data[0] = 0;
data[1] = 0;
a.encode_vec_i64(log_base2k, log_k, &data, 16);
let mut a_dft: base2k::VecZnxDft = module.new_vec_znx_dft(cols);
module.vec_znx_dft(&mut a_dft, &a, cols);
(0..cols).for_each(|i| {
println!("{:?}", a_dft.at::<f64>(&module, i));
})
}
pub struct GadgetCiphertextProtocol {}
impl GadgetCiphertextProtocol {
pub fn new() -> GadgetCiphertextProtocol {
Self {}
}
pub fn allocate(
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> GadgetCiphertextShare {
GadgetCiphertextShare::new(module, log_base2k, rows, log_q)
}
pub fn gen_share(
module: &Module,
sk: &SecretKey,
pt: &Plaintext,
seed: &[u8; 32],
share: &mut GadgetCiphertextShare,
tmp_bytes: &mut [u8],
) {
share.seed.copy_from_slice(seed);
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(*seed);
let mut sk_ppol: SvpPPol = module.new_svp_ppol();
sk.prepare(module, &mut sk_ppol);
share.value.iter_mut().for_each(|ai| {
//let elem = Elem<VecZnx>{};
//encrypt_rlwe_sk_thread_safe(module, ai, Some(pt.elem()), &sk_ppol, &mut source_xa, &mut source_xe, 3.2, tmp_bytes);
})
}
}
pub struct GadgetCiphertextShare {
pub seed: [u8; 32],
pub log_q: usize,
pub log_base2k: usize,
pub value: Vec<VecZnx>,
}
impl GadgetCiphertextShare {
pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self {
let mut value: Vec<VecZnx> = Vec::new();
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
(0..rows).for_each(|_| {
value.push(module.new_vec_znx(cols));
});
Self {
seed: [u8::default(); 32],
log_q,
log_base2k,
value,
}
}
pub fn rows(&self) -> usize {
self.value.len()
}
pub fn cols(&self) -> usize {
self.value[0].cols()
}
pub fn aggregate_inplace(&mut self, module: &Module, a: &GadgetCiphertextShare) {
izip!(self.value.iter_mut(), a.value.iter()).for_each(|(bi, ai)| {
module.vec_znx_add_inplace(bi, ai);
})
}
pub fn get(&self, module: &Module, b: &mut Ciphertext<VmpPMat>, tmp_bytes: &mut [u8]) {
assert!(is_aligned(tmp_bytes.as_ptr()));
let rows: usize = b.rows();
let cols: usize = b.cols();
assert!(tmp_bytes.len() >= gadget_ciphertext_share_get_tmp_bytes(module, rows, cols));
assert_eq!(self.value.len(), rows);
assert_eq!(self.value[0].cols(), cols);
let (tmp_bytes_vmp_prepare_row, tmp_bytes_vec_znx) =
tmp_bytes.split_at_mut(module.vmp_prepare_tmp_bytes(rows, cols));
let mut c: VecZnx = VecZnx::from_bytes_borrow(module.n(), cols, tmp_bytes_vec_znx);
let mut source: Source = Source::new(self.seed);
(0..self.value.len()).for_each(|row_i| {
module.vmp_prepare_row(
b.at_mut(0),
self.value[row_i].raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);
module.fill_uniform(self.log_base2k, &mut c, cols, &mut source);
module.vmp_prepare_row(b.at_mut(1), c.raw(), row_i, tmp_bytes_vmp_prepare_row)
})
}
pub fn get_new(&self, module: &Module, tmp_bytes: &mut [u8]) -> Ciphertext<VmpPMat> {
let mut b: Ciphertext<VmpPMat> =
new_gadget_ciphertext(module, self.log_base2k, self.rows(), self.log_q);
self.get(module, &mut b, tmp_bytes);
b
}
}
pub fn gadget_ciphertext_share_get_tmp_bytes(module: &Module, rows: usize, cols: usize) -> usize {
module.vmp_prepare_tmp_bytes(rows, cols) + module.bytes_of_vec_znx(cols)
}
pub struct CircularCiphertextProtocol {}
pub struct CircularGadgetCiphertextProtocol {}

View File

@@ -1,11 +1,11 @@
use crate::{ use crate::{
ciphertext::Ciphertext, ciphertext::Ciphertext,
elem::{Elem, ElemCommon, VecZnxCommon}, elem::{Elem, ElemCommon},
keys::SecretKey, keys::SecretKey,
parameters::Parameters, parameters::Parameters,
plaintext::Plaintext, plaintext::Plaintext,
}; };
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
use std::cmp::min; use std::cmp::min;
pub struct Decryptor { pub struct Decryptor {
@@ -32,30 +32,24 @@ impl Parameters {
) )
} }
pub fn decrypt_rlwe<T>( pub fn decrypt_rlwe(
&self, &self,
res: &mut Plaintext<T>, res: &mut Plaintext,
ct: &Ciphertext<T>, ct: &Ciphertext<VecZnx>,
sk: &SvpPPol, sk: &SvpPPol,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
} }
} }
pub fn decrypt_rlwe<T>( pub fn decrypt_rlwe(
module: &Module, module: &Module,
res: &mut Elem<T>, res: &mut Elem<VecZnx>,
a: &Elem<T>, a: &Elem<VecZnx>,
sk: &SvpPPol, sk: &SvpPPol,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
let cols: usize = a.cols(); let cols: usize = a.cols();
assert!( assert!(
@@ -65,9 +59,11 @@ pub fn decrypt_rlwe<T>(
decrypt_rlwe_tmp_byte(module, cols) decrypt_rlwe_tmp_byte(module, cols)
); );
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols); let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) =
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.cols(), tmp_bytes); let mut res_dft: VecZnxDft =
VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk) // res_dft <- DFT(ct[1]) * DFT(sk)
@@ -77,12 +73,7 @@ pub fn decrypt_rlwe<T>(
// res_big <- ct[1] x sk + ct[0] // res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
// res <- normalize(ct[1] x sk + ct[0]) // res <- normalize(ct[1] x sk + ct[0])
module.vec_znx_big_normalize( module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize);
a.log_base2k(),
res.at_mut(0),
&res_big,
&mut tmp_bytes[res_dft_bytes..],
);
res.log_base2k = a.log_base2k(); res.log_base2k = a.log_base2k();
res.log_q = min(res.log_q(), a.log_q()); res.log_q = min(res.log_q(), a.log_q());

View File

@@ -1,17 +1,7 @@
use base2k::{Infos, Module, VecZnx, VecZnxBorrow, VecZnxOps, VmpPMat, VmpPMatOps}; use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
use crate::parameters::Parameters; use crate::parameters::Parameters;
impl Parameters {
pub fn elem_from_bytes<T>(&self, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<T>
where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
Elem::<T>::from_bytes(self.module(), self.log_base2k(), log_q, size, bytes)
}
}
pub struct Elem<T> { pub struct Elem<T> {
pub value: Vec<T>, pub value: Vec<T>,
pub log_base2k: usize, pub log_base2k: usize,
@@ -19,26 +9,26 @@ pub struct Elem<T> {
pub log_scale: usize, pub log_scale: usize,
} }
pub trait VecZnxCommon: base2k::VecZnxCommon {} pub trait ElemVecZnx {
impl VecZnxCommon for VecZnx {}
impl VecZnxCommon for VecZnxBorrow {}
pub trait ElemVecZnx<T: VecZnxCommon<Owned = T>> {
fn from_bytes( fn from_bytes(
module: &Module, module: &Module,
log_base2k: usize, log_base2k: usize,
log_q: usize, log_q: usize,
size: usize, size: usize,
bytes: &mut [u8], bytes: &mut [u8],
) -> Elem<T>; ) -> Elem<VecZnx>;
fn from_bytes_borrow(
module: &Module,
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx>;
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize; fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
fn zero(&mut self); fn zero(&mut self);
} }
impl<T> ElemVecZnx<T> for Elem<T> impl ElemVecZnx for Elem<VecZnx> {
where
T: VecZnxCommon<Owned = T>,
{
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize { fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k; let cols = (log_q + log_base2k - 1) / log_base2k;
module.n() * cols * size * 8 module.n() * cols * size * 8
@@ -50,16 +40,42 @@ where
log_q: usize, log_q: usize,
size: usize, size: usize,
bytes: &mut [u8], bytes: &mut [u8],
) -> Elem<T> { ) -> Elem<VecZnx> {
assert!(size > 0); assert!(size > 0);
let n: usize = module.n(); let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
let mut value: Vec<T> = Vec::new(); let mut value: Vec<VecZnx> = Vec::new();
let limbs: usize = (log_q + log_base2k - 1) / log_base2k; let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = T::bytes_of(n, limbs); let elem_size = VecZnx::bytes_of(n, limbs);
let mut ptr: usize = 0; let mut ptr: usize = 0;
(0..size).for_each(|_| { (0..size).for_each(|_| {
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..])); value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..]));
ptr += elem_size
});
Self {
value,
log_q,
log_base2k,
log_scale: 0,
}
}
fn from_bytes_borrow(
module: &Module,
log_base2k: usize,
log_q: usize,
size: usize,
bytes: &mut [u8],
) -> Elem<VecZnx> {
assert!(size > 0);
let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
let mut value: Vec<VecZnx> = Vec::new();
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = VecZnx::bytes_of(n, limbs);
let mut ptr: usize = 0;
(0..size).for_each(|_| {
value.push(VecZnx::from_bytes_borrow(n, limbs, &mut bytes[ptr..]));
ptr += elem_size ptr += elem_size
}); });
Self { Self {

View File

@@ -1,12 +1,12 @@
use crate::ciphertext::Ciphertext; use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; use crate::elem::{Elem, ElemCommon, ElemVecZnx};
use crate::keys::SecretKey; use crate::keys::SecretKey;
use crate::parameters::Parameters; use crate::parameters::Parameters;
use crate::plaintext::Plaintext; use crate::plaintext::Plaintext;
use base2k::sampling::Sampling; use base2k::sampling::Sampling;
use base2k::{ use base2k::{
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, VecZnxOps, VmpPMat, VmpPMatOps,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};
@@ -49,20 +49,17 @@ impl EncryptorSk {
self.source_xe = Source::new(seed) self.source_xe = Source::new(seed)
} }
pub fn encrypt_rlwe_sk<T>( pub fn encrypt_rlwe_sk(
&mut self, &mut self,
params: &Parameters, params: &Parameters,
ct: &mut Ciphertext<T>, ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext<T>>, pt: Option<&Plaintext>,
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
assert!( assert!(
self.initialized == true, self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
); );
params.encrypt_rlwe_sk_thread_safe( params.encrypt_rlwe_sk(
ct, ct,
pt, pt,
&self.sk, &self.sk,
@@ -72,23 +69,20 @@ impl EncryptorSk {
); );
} }
pub fn encrypt_rlwe_sk_thread_safe<T>( pub fn encrypt_rlwe_sk_core(
&self, &self,
params: &Parameters, params: &Parameters,
ct: &mut Ciphertext<T>, ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext<T>>, pt: Option<&Plaintext>,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
assert!( assert!(
self.initialized == true, self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk_thread_safe]: [EncryptorSk] has not been initialized with a [SecretKey]" "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
); );
params.encrypt_rlwe_sk_thread_safe(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes); params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes);
} }
} }
@@ -97,19 +91,16 @@ impl Parameters {
encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
} }
pub fn encrypt_rlwe_sk_thread_safe<T>( pub fn encrypt_rlwe_sk(
&self, &self,
ct: &mut Ciphertext<T>, ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext<T>>, pt: Option<&Plaintext>,
sk: &SvpPPol, sk: &SvpPPol,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>, encrypt_rlwe_sk(
Elem<T>: ElemCommon<T>,
{
encrypt_rlwe_sk_thread_safe(
self.module(), self.module(),
&mut ct.0, &mut ct.0,
pt.map(|pt| &pt.0), pt.map(|pt| &pt.0),
@@ -127,19 +118,16 @@ pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usiz
+ module.vec_znx_big_normalize_tmp_bytes() + module.vec_znx_big_normalize_tmp_bytes()
} }
pub fn encrypt_rlwe_sk_thread_safe<T>( pub fn encrypt_rlwe_sk(
module: &Module, module: &Module,
ct: &mut Elem<T>, ct: &mut Elem<VecZnx>,
pt: Option<&Elem<T>>, pt: Option<&Elem<VecZnx>>,
sk: &SvpPPol, sk: &SvpPPol,
source_xa: &mut Source, source_xa: &mut Source,
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
let cols: usize = ct.cols(); let cols: usize = ct.cols();
let log_base2k: usize = ct.log_base2k(); let log_base2k: usize = ct.log_base2k();
let log_q: usize = ct.log_q(); let log_q: usize = ct.log_q();
@@ -153,16 +141,16 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
let log_q: usize = ct.log_q(); let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k(); let log_base2k: usize = ct.log_base2k();
let c1: &mut T = ct.at_mut(1); let c1: &mut VecZnx = ct.at_mut(1);
// c1 <- Z_{2^prec}[X]/(X^{N}+1) // c1 <- Z_{2^prec}[X]/(X^{N}+1)
module.fill_uniform(log_base2k, c1, cols, source_xa); module.fill_uniform(log_base2k, c1, cols, source_xa);
let bytes_of_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols); let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) =
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
// Scratch space for DFT values // Scratch space for DFT values
let mut buf_dft: VecZnxDft = let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
VecZnxDft::from_bytes(cols, &mut tmp_bytes[..bytes_of_vec_znx_dft]);
// Applies buf_dft <- DFT(s) * DFT(c1) // Applies buf_dft <- DFT(s) * DFT(c1)
module.svp_apply_dft(&mut buf_dft, sk, c1, cols); module.svp_apply_dft(&mut buf_dft, sk, c1, cols);
@@ -173,16 +161,14 @@ pub fn encrypt_rlwe_sk_thread_safe<T>(
// buf_big = s x c1 // buf_big = s x c1
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, cols); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, cols);
let carry: &mut [u8] = &mut tmp_bytes[bytes_of_vec_znx_dft..];
// c0 <- -s x c1 + m // c0 <- -s x c1 + m
let c0: &mut T = ct.at_mut(0); let c0: &mut VecZnx = ct.at_mut(0);
if let Some(pt) = pt { if let Some(pt) = pt {
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0));
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
} else { } else {
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
module.vec_znx_negate_inplace(c0); module.vec_znx_negate_inplace(c0);
} }
@@ -211,7 +197,7 @@ pub fn encrypt_grlwe_sk_tmp_bytes(
) -> usize { ) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k; let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2) Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::<VecZnx>::bytes_of(module, log_base2k, log_q) + Plaintext::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, cols) + module.vmp_prepare_tmp_bytes(rows, cols)
} }
@@ -240,25 +226,25 @@ pub fn encrypt_grlwe_sk(
min_tmp_bytes_len min_tmp_bytes_len
); );
let bytes_of_elem: usize = Elem::<VecZnxBorrow>::bytes_of(module, log_base2k, log_q, 2); let bytes_of_elem: usize = Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2);
let bytes_of_pt: usize = Plaintext::<VecZnx>::bytes_of(module, log_base2k, log_q); let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q);
let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q); let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q);
let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt); let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt);
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk); let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
let mut tmp_elem: Elem<VecZnxBorrow> = let mut tmp_elem: Elem<VecZnx> =
Elem::<VecZnxBorrow>::from_bytes(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem); Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem);
let mut tmp_pt: Plaintext<VecZnxBorrow> = let mut tmp_pt: Plaintext =
Plaintext::<VecZnxBorrow>::from_bytes(module, log_base2k, log_q, tmp_bytes_pt); Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
(0..rows).for_each(|row_i| { (0..rows).for_each(|row_i| {
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.0); tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
// Encrypts RLWE(m * 2^{-log_base2k*i}) // Encrypts RLWE(m * 2^{-log_base2k*i})
encrypt_rlwe_sk_thread_safe( encrypt_rlwe_sk(
module, module,
&mut tmp_elem, &mut tmp_elem,
Some(&tmp_pt.0), Some(&tmp_pt.0),

View File

@@ -1,9 +1,5 @@
use crate::{ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
ciphertext::Ciphertext, use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon},
parameters::Parameters,
};
use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min; use std::cmp::min;
pub fn gadget_product_tmp_bytes( pub fn gadget_product_tmp_bytes(
@@ -53,19 +49,16 @@ impl Parameters {
/// ///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols. /// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
pub fn gadget_product_core<T>( pub fn gadget_product_core(
module: &Module, module: &Module,
res_dft_0: &mut VecZnxDft, res_dft_0: &mut VecZnxDft,
res_dft_1: &mut VecZnxDft, res_dft_1: &mut VecZnxDft,
a: &T, a: &VecZnx,
a_cols: usize, a_cols: usize,
b: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>,
b_cols: usize, b_cols: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
assert!(b_cols <= b.cols()); assert!(b_cols <= b.cols());
module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols)); module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols));
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
@@ -104,7 +97,7 @@ mod test {
plaintext::Plaintext, plaintext::Plaintext,
}; };
use base2k::{ use base2k::{
FFT64, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, Infos, MODULETYPE, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};
@@ -117,6 +110,7 @@ mod test {
// Basic parameters with enough limbs to test edge cases // Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral { let params_lit: ParametersLiteral = ParametersLiteral {
backend: MODULETYPE::FFT64,
log_n: 12, log_n: 12,
log_q: q_cols * log_base2k, log_q: q_cols * log_base2k,
log_p: p_cols * log_base2k, log_p: p_cols * log_base2k,
@@ -126,7 +120,7 @@ mod test {
xs: 1 << 11, xs: 1 << 11,
}; };
let params: Parameters = Parameters::new::<FFT64>(&params_lit); let params: Parameters = Parameters::new(&params_lit);
// scratch space // scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8( let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
@@ -213,8 +207,8 @@ mod test {
); );
// Plaintext for decrypted output of gadget product // Plaintext for decrypted output of gadget product
let mut pt: Plaintext<VecZnx> = let mut pt: Plaintext =
Plaintext::<VecZnx>::new(params.module(), params.log_base2k(), params.log_qp()); Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext. // Iterates over all possible cols values for input/output polynomials and gadget ciphertext.

View File

@@ -1,6 +1,6 @@
use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
use crate::elem::{Elem, ElemCommon}; use crate::elem::{Elem, ElemCommon};
use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes}; use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes};
use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat}; use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
use sampling::source::Source; use sampling::source::Source;
@@ -40,7 +40,7 @@ impl PublicKey {
xe_source: &mut Source, xe_source: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
encrypt_rlwe_sk_thread_safe( encrypt_rlwe_sk(
module, module,
&mut self.0, &mut self.0,
None, None,

View File

@@ -1,6 +1,7 @@
use base2k::module::{MODULETYPE, Module}; use base2k::module::{MODULETYPE, Module};
pub struct ParametersLiteral { pub struct ParametersLiteral {
pub backend: MODULETYPE,
pub log_n: usize, pub log_n: usize,
pub log_q: usize, pub log_q: usize,
pub log_p: usize, pub log_p: usize,
@@ -22,7 +23,7 @@ pub struct Parameters {
} }
impl Parameters { impl Parameters {
pub fn new<const MTYPE: MODULETYPE>(p: &ParametersLiteral) -> Self { pub fn new(p: &ParametersLiteral) -> Self {
assert!( assert!(
p.log_n + 2 * p.log_base2k <= 53, p.log_n + 2 * p.log_base2k <= 53,
"invalid parameters: p.log_n + 2*p.log_base2k > 53" "invalid parameters: p.log_n + 2*p.log_base2k > 53"
@@ -35,7 +36,7 @@ impl Parameters {
log_base2k: p.log_base2k, log_base2k: p.log_base2k,
xe: p.xe, xe: p.xe,
xs: p.xs, xs: p.xs,
module: Module::new::<MTYPE>(1 << p.log_n), module: Module::new(1 << p.log_n, p.backend),
} }
} }

View File

@@ -1,61 +1,65 @@
use crate::ciphertext::Ciphertext; use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; use crate::elem::{Elem, ElemCommon, ElemVecZnx};
use crate::parameters::Parameters; use crate::parameters::Parameters;
use base2k::{Module, VecZnx}; use base2k::{Module, VecZnx};
pub struct Plaintext<T>(pub Elem<T>); pub struct Plaintext(pub Elem<VecZnx>);
impl Parameters { impl Parameters {
pub fn new_plaintext(&self, log_q: usize) -> Plaintext<VecZnx> { pub fn new_plaintext(&self, log_q: usize) -> Plaintext {
Plaintext::new(self.module(), self.log_base2k(), log_q) Plaintext::new(self.module(), self.log_base2k(), log_q)
} }
pub fn bytes_of_plaintext<T>(&self, log_q: usize) -> usize pub fn bytes_of_plaintext(&self, log_q: usize) -> usize
where {
T: VecZnxCommon<Owned = T>, Elem::<VecZnx>::bytes_of(self.module(), self.log_base2k(), log_q, 1)
Elem<T>: ElemVecZnx<T>,
{
Elem::<T>::bytes_of(self.module(), self.log_base2k(), log_q, 1)
} }
pub fn plaintext_from_bytes<T>(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext<T> pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext {
where Plaintext(Elem::<VecZnx>::from_bytes(
T: VecZnxCommon<Owned = T>, self.module(),
Elem<T>: ElemVecZnx<T>, self.log_base2k(),
{ log_q,
Plaintext::<T>(self.elem_from_bytes::<T>(log_q, 1, bytes)) 1,
bytes,
))
} }
} }
impl Plaintext<VecZnx> { impl Plaintext {
pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self {
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, 1)) Self(Elem::<VecZnx>::new(module, log_base2k, log_q, 1))
} }
} }
impl<T> Plaintext<T> impl Plaintext {
where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize {
Elem::<T>::bytes_of(module, log_base2k, log_q, 1) Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 1)
} }
pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
Self(Elem::<T>::from_bytes(module, log_base2k, log_q, 1, bytes)) Self(Elem::<VecZnx>::from_bytes(
module, log_base2k, log_q, 1, bytes,
))
} }
pub fn as_ciphertext(&self) -> Ciphertext<T> { pub fn from_bytes_borrow(
unsafe { Ciphertext::<T>(std::ptr::read(&self.0)) } module: &Module,
log_base2k: usize,
log_q: usize,
bytes: &mut [u8],
) -> Self {
Self(Elem::<VecZnx>::from_bytes_borrow(
module, log_base2k, log_q, 1, bytes,
))
}
pub fn as_ciphertext(&self) -> Ciphertext<VecZnx> {
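// Safety note: bitwise-copies the inner Elem, so the returned Ciphertext
// aliases self's buffers; dropping both would free the same memory twice.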
unsafe { Ciphertext::<VecZnx>(std::ptr::read(&self.0)) }
} }
} }
impl<T> ElemCommon<T> for Plaintext<T> impl ElemCommon<VecZnx> for Plaintext {
where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
fn n(&self) -> usize { fn n(&self) -> usize {
self.0.n() self.0.n()
} }
@@ -68,11 +72,11 @@ where
self.0.log_q self.0.log_q
} }
fn elem(&self) -> &Elem<T> { fn elem(&self) -> &Elem<VecZnx> {
&self.0 &self.0
} }
fn elem_mut(&mut self) -> &mut Elem<T> { fn elem_mut(&mut self) -> &mut Elem<VecZnx> {
&mut self.0 &mut self.0
} }
@@ -88,11 +92,11 @@ where
self.0.cols() self.0.cols()
} }
fn at(&self, i: usize) -> &T { fn at(&self, i: usize) -> &VecZnx {
self.0.at(i) self.0.at(i)
} }
fn at_mut(&mut self, i: usize) -> &mut T { fn at_mut(&mut self, i: usize) -> &mut VecZnx {
self.0.at_mut(i) self.0.at_mut(i)
} }

View File

@@ -1,20 +1,19 @@
use crate::{ use crate::{
ciphertext::Ciphertext, ciphertext::Ciphertext,
elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}, elem::{Elem, ElemCommon, ElemVecZnx},
};
use base2k::{
Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps,
}; };
use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min; use std::cmp::min;
pub fn rgsw_product<T>( pub fn rgsw_product(
module: &Module, module: &Module,
_res: &mut Elem<T>, _res: &mut Elem<VecZnx>,
a: &Ciphertext<T>, a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) where ) {
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
let _log_base2k: usize = b.log_base2k(); let _log_base2k: usize = b.log_base2k();
let rows: usize = min(b.rows(), a.cols()); let rows: usize = min(b.rows(), a.cols());
let cols: usize = b.cols(); let cols: usize = b.cols();