This commit is contained in:
Janmajaya Mall
2025-05-02 20:49:04 +05:30
parent ca5e6d46c9
commit 3ed6fa8ab5
8 changed files with 770 additions and 443 deletions

View File

@@ -1,12 +1,16 @@
use crate::Backend; use crate::Backend;
use crate::DataView;
use crate::DataViewMut;
use crate::Module; use crate::Module;
use crate::ZnxView;
use crate::alloc_aligned;
use crate::assert_alignement; use crate::assert_alignement;
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::znx; use crate::ffi::znx;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree};
use std::cmp::min; use std::{cmp::min, fmt};
pub const VEC_ZNX_ROWS: usize = 1; // pub const VEC_ZNX_ROWS: usize = 1;
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
/// Zn\[X\] with [i64] coefficients. /// Zn\[X\] with [i64] coefficients.
@@ -18,68 +22,57 @@ pub const VEC_ZNX_ROWS: usize = 1;
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\]. /// are small polynomials of Zn\[X\].
pub struct VecZnx { pub struct VecZnx<D> {
pub inner: ZnxBase, data: D,
n: usize,
cols: usize,
size: usize,
} }
impl GetZnxBase for VecZnx { impl<D> ZnxInfos for VecZnx<D> {
fn znx(&self) -> &ZnxBase { fn cols(&self) -> usize {
&self.inner self.cols
} }
fn znx_mut(&mut self) -> &mut ZnxBase { fn rows(&self) -> usize {
&mut self.inner 1
} }
}
impl ZnxInfos for VecZnx {} fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
impl ZnxSliceSize for VecZnx {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.cols() * self.n() self.cols() * self.n()
} }
} }
impl ZnxLayout for VecZnx { impl<D> DataView for VecZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D> DataViewMut for VecZnx<D> {
fn data_mut(&self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
type Scalar = i64; type Scalar = i64;
} }
impl ZnxZero for VecZnx {} impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
impl ZnxRsh for VecZnx {} normalize(log_base2k, self, col, carry)
impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnx {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
}
} }
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_ROWS,
"rows != {} not supported for VecZnx",
VEC_ZNX_ROWS
);
module.n() * cols * size * size_of::<Self::Scalar>()
}
}
/// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
assert_eq!(b.cols(), a.cols());
let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len());
data_b[..size].copy_from_slice(&data_a[..size])
}
impl VecZnx {
/// Truncates the precision of the [VecZnx] by k bits. /// Truncates the precision of the [VecZnx] by k bits.
/// ///
/// # Arguments /// # Arguments
@@ -91,12 +84,6 @@ impl VecZnx {
return; return;
} }
if !self.borrowing() {
self.inner
.data
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
}
self.inner.size -= k / log_base2k; self.inner.size -= k / log_base2k;
let k_rem: usize = k % log_base2k; let k_rem: usize = k % log_base2k;
@@ -109,29 +96,72 @@ impl VecZnx {
} }
} }
pub fn copy_from(&mut self, a: &Self) { /// Switches degree of from `a.n()` to `self.n()` into `self`
copy_vec_znx_from(self, a); pub fn switch_degree<Data: AsRef<[u8]>>(&mut self, col: usize, a: &Data, col_a: usize) {
} switch_degree(self, col_a, a, col)
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
normalize(log_base2k, self, col, carry)
}
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
switch_degree(a, col_a, self, col)
} }
// Prints the first `n` coefficients of each limb // Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize, col: usize) { // pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); // (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
// }
}
impl<D: From<Vec<u8>>> VecZnx<D> {
pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<Scalar>()
}
pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
}
}
pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
}
} }
} }
//(Jay)TODO: Impl. truncate pow2 for Owned Vector
/// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
where
DataMut: AsMut<[u8]> + AsRef<[u8]>,
Data: AsRef<[u8]>,
{
assert_eq!(b.cols(), a.cols());
let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len());
data_b[..size].copy_from_slice(&data_a[..size])
}
// if !self.borrowing() {
// self.inner
// .data
// .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
// }
fn normalize_tmp_bytes(n: usize) -> usize { fn normalize_tmp_bytes(n: usize) -> usize {
n * std::mem::size_of::<i64>() n * std::mem::size_of::<i64>()
} }
fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { fn normalize<D: AsMut<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
debug_assert!( debug_assert!(
@@ -162,3 +192,62 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u
}); });
} }
} }
// impl<B: Backend> ZnxAlloc<B> for VecZnx {
// type Scalar = i64;
// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
// VecZnx {
// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
// }
// }
// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
// debug_assert_eq!(
// _rows, VEC_ZNX_ROWS,
// "rows != {} not supported for VecZnx",
// VEC_ZNX_ROWS
// );
// module.n() * cols * size * size_of::<Self::Scalar>()
// }
// }
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnx(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;

View File

@@ -1,74 +1,91 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxView};
use crate::{Backend, FFT64, Module, NTT120}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned};
use std::marker::PhantomData; use std::marker::PhantomData;
const VEC_ZNX_BIG_ROWS: usize = 1; const VEC_ZNX_BIG_ROWS: usize = 1;
pub struct VecZnxBig<B: Backend> { /// VecZnxBig is Backend dependent, denoted with backend generic `B`
pub inner: ZnxBase, pub struct VecZnxBig<D, B> {
pub _marker: PhantomData<B>, data: D,
n: usize,
cols: usize,
size: usize,
_phantom: PhantomData<B>,
} }
impl<B: Backend> GetZnxBase for VecZnxBig<B> { impl<D, B> ZnxInfos for VecZnxBig<D, B> {
fn znx(&self) -> &ZnxBase { fn cols(&self) -> usize {
&self.inner self.cols
} }
fn znx_mut(&mut self) -> &mut ZnxBase { fn rows(&self) -> usize {
&mut self.inner 1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
fn sl(&self) -> usize {
self.cols() * self.n()
} }
} }
impl<B: Backend> ZnxInfos for VecZnxBig<B> {} impl<D, B> DataView for VecZnxBig<D, B> {
type D = D;
impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> { fn data(&self) -> &Self::D {
type Scalar = u8; &self.data
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnxBig {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
_marker: PhantomData,
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_BIG_ROWS,
"rows != {} not supported for VecZnxBig",
VEC_ZNX_BIG_ROWS
);
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
} }
} }
impl ZnxLayout for VecZnxBig<FFT64> { impl<D, B> DataViewMut for VecZnxBig<D, B> {
fn data_mut(&self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
type Scalar = i64; type Scalar = i64;
} }
impl ZnxLayout for VecZnxBig<NTT120> { impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
type Scalar = i128; pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
} unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
}
impl ZnxZero for VecZnxBig<FFT64> {} pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
impl ZnxSliceSize for VecZnxBig<FFT64> { pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
fn sl(&self) -> usize { let data: Vec<u8> = bytes.into();
self.n() * self.cols() assert!(data.len() == Self::bytes_of(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
} }
} }
impl ZnxSliceSize for VecZnxBig<NTT120> { pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
fn sl(&self) -> usize {
self.n() * 4 * self.cols()
}
}
impl ZnxZero for VecZnxBig<NTT120> {} // impl VecZnxBig<FFT64> {
// pub fn print(&self, n: usize, col: usize) {
impl VecZnxBig<FFT64> { // (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
pub fn print(&self, n: usize, col: usize) { // }
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); // }
}
}

View File

@@ -1,10 +1,10 @@
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxView, ZnxViewMut};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; use crate::{Backend, DataView, FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement};
pub trait VecZnxBigOps<B: Backend> { pub trait VecZnxBigAlloc<B> {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>; fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// ///
@@ -18,98 +18,100 @@ pub trait VecZnxBigOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<B>; fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. // /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// // ///
/// Behavior: the backing array is only borrowed. // /// Behavior: the backing array is only borrowed.
/// // ///
/// # Arguments // /// # Arguments
/// // ///
/// * `cols`: the number of polynomials.. // /// * `cols`: the number of polynomials..
/// * `size`: the number of polynomials per column. // /// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
/// // ///
/// # Panics // /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. // /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>; // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes]. /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxBigOps<DataMut, Data, B> {
/// Adds `a` to `b` and stores the result on `c`. /// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add( fn vec_znx_big_add(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<DataMut, B>,
res_col: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<Data, B>,
a_col: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<Data, B>,
b_col: usize, b_col: usize,
); );
/// Adds `a` to `b` and stores the result on `b`. /// Adds `a` to `b` and stores the result on `b`.
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize); fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
/// Adds `a` to `b` and stores the result on `c`. /// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add_small( fn vec_znx_big_add_small(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<DataMut, B>,
res_col: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<Data, B>,
a_col: usize, a_col: usize,
b: &VecZnx, b: &VecZnx<Data>,
b_col: usize, b_col: usize,
); );
/// Adds `a` to `b` and stores the result on `b`. /// Adds `a` to `b` and stores the result on `b`.
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Subtracts `a` to `b` and stores the result on `c`. /// Subtracts `a` to `b` and stores the result on `c`.
fn vec_znx_big_sub( fn vec_znx_big_sub(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<DataMut, B>,
res_col: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<Data, B>,
a_col: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<Data, B>,
b_col: usize, b_col: usize,
); );
/// Subtracts `a` to `b` and stores the result on `b`. /// Subtracts `a` to `b` and stores the result on `b`.
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize); fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `b`. /// Subtracts `b` to `a` and stores the result on `b`.
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize); fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `c`. /// Subtracts `b` to `a` and stores the result on `c`.
fn vec_znx_big_sub_small_a( fn vec_znx_big_sub_small_a(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<DataMut, B>,
res_col: usize, res_col: usize,
a: &VecZnx, a: &VecZnx<Data>,
a_col: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<Data, B>,
b_col: usize, b_col: usize,
); );
/// Subtracts `a` to `b` and stores the result on `b`. /// Subtracts `a` to `b` and stores the result on `b`.
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `c`. /// Subtracts `b` to `a` and stores the result on `c`.
fn vec_znx_big_sub_small_b( fn vec_znx_big_sub_small_b(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<DataMut, B>,
res_col: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<Data, B>,
a_col: usize, a_col: usize,
b: &VecZnx, b: &VecZnx<Data>,
b_col: usize, b_col: usize,
); );
/// Subtracts `b` to `a` and stores the result on `b`. /// Subtracts `b` to `a` and stores the result on `b`.
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
@@ -123,44 +125,57 @@ pub trait VecZnxBigOps<B: Backend> {
fn vec_znx_big_normalize( fn vec_znx_big_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx<DataMut>,
res_col: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<Data, B>,
a_col: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize); fn vec_znx_big_automorphism(
&self,
k: i64,
res: &mut VecZnxBig<DataMut, B>,
res_col: usize,
a: &VecZnxBig<Data, B>,
a_col: usize,
);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, a_col: usize); fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, B>, a_col: usize);
} }
impl VecZnxBigOps<FFT64> for Module<FFT64> { impl VecZnxBigAlloc<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
VecZnxBig::new(self, 1, cols, size) VecZnxBig::new(self, cols, size)
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<FFT64> { fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
VecZnxBig::from_bytes(self, 1, cols, size, bytes) VecZnxBig::new_from_bytes(self, cols, size, bytes)
} }
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> { // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) // VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes)
} // }
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, 1, cols, size) VecZnxBig::bytes_of(self, cols, size)
} }
}
impl<DataMut, Data> VecZnxBigOps<DataMut, Data, FFT64> for Module<FFT64>
where
DataMut: AsMut<[u8]> + AsRef<[u8]>,
Data: AsRef<[u8]>,
{
fn vec_znx_big_add( fn vec_znx_big_add(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<Data, FFT64>,
a_col: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<Data, FFT64>,
b_col: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -186,20 +201,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) { fn vec_znx_big_add_inplace(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnxBig<Data, FFT64>,
a_col: usize,
) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; Self::vec_znx_big_add(self, res, res_col, a, a_col, res, res_col);
Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_sub( fn vec_znx_big_sub(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<Data, FFT64>,
a_col: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<Data, FFT64>,
b_col: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -225,27 +245,38 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) { //(Jay)TODO: check whether definitions sub_ab, sub_ba make sense to you
fn vec_znx_big_sub_ab_inplace(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnxBig<Data, FFT64>,
a_col: usize,
) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; Self::vec_znx_big_sub(self, res, res_col, a, a_col, res, res_col);
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) { fn vec_znx_big_sub_ba_inplace(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnxBig<Data, FFT64>,
a_col: usize,
) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; Self::vec_znx_big_sub(self, res, res_col, res, res_col, a, a_col);
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} }
} }
fn vec_znx_big_sub_small_b( fn vec_znx_big_sub_small_b(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<Data, FFT64>,
a_col: usize, a_col: usize,
b: &VecZnx, b: &VecZnx<Data>,
b_col: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -271,20 +302,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_big_sub_small_b_inplace(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; Self::vec_znx_big_sub_small_b(self, res, res_col, res, res_col, a, a_col);
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} }
} }
fn vec_znx_big_sub_small_a( fn vec_znx_big_sub_small_a(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize, res_col: usize,
a: &VecZnx, a: &VecZnx<Data>,
a_col: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<Data, FFT64>,
b_col: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -310,20 +346,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_big_sub_small_a_inplace(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; Self::vec_znx_big_sub_small_a(self, res, res_col, a, a_col, res, res_col);
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_add_small( fn vec_znx_big_add_small(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<Data, FFT64>,
a_col: usize, a_col: usize,
b: &VecZnx, b: &VecZnx<Data>,
b_col: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -349,11 +390,8 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<DataMut, FFT64>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
unsafe { Self::vec_znx_big_add_small(self, res, res_col, res, res_col, a, a_col);
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
}
} }
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
@@ -363,9 +401,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn vec_znx_big_normalize( fn vec_znx_big_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx<DataMut>,
res_col: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<Data, FFT64>,
a_col: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
@@ -391,7 +429,14 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) { fn vec_znx_big_automorphism(
&self,
k: i64,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnxBig<Data, FFT64>,
a_col: usize,
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -411,10 +456,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, a_col: usize) { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, FFT64>, a_col: usize) {
unsafe { unsafe {
let a_ptr: *mut VecZnxBig<FFT64> = a as *mut VecZnxBig<FFT64>; Self::vec_znx_big_automorphism(self, k, a, a_col, a, a_col);
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
} }
} }
} }

View File

@@ -1,85 +1,135 @@
use crate::ffi::vec_znx_dft;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
use crate::{Backend, FFT64, Module, VecZnxBig};
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::znx_base::{ZnxAlloc, ZnxInfos};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
const VEC_ZNX_DFT_ROWS: usize = 1; const VEC_ZNX_DFT_ROWS: usize = 1;
pub struct VecZnxDft<B: Backend> { pub struct VecZnxDft<D, B> {
inner: ZnxBase, data: D,
pub _marker: PhantomData<B>, n: usize,
cols: usize,
size: usize,
_phantom: PhantomData<B>,
} }
impl<B: Backend> GetZnxBase for VecZnxDft<B> { impl<D, B> ZnxInfos for VecZnxDft<D, B> {
fn znx(&self) -> &ZnxBase { fn cols(&self) -> usize {
&self.inner self.cols
} }
fn znx_mut(&mut self) -> &mut ZnxBase { fn rows(&self) -> usize {
&mut self.inner 1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
fn sl(&self) -> usize {
self.cols() * self.n()
} }
} }
impl<B: Backend> ZnxInfos for VecZnxDft<B> {} impl<D, B> DataView for VecZnxDft<D, B> {
type D = D;
impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> { fn data(&self) -> &Self::D {
type Scalar = u8; &self.data
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
_marker: PhantomData,
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_DFT_ROWS,
"rows != {} not supported for VecZnxDft",
VEC_ZNX_DFT_ROWS
);
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
} }
} }
impl ZnxLayout for VecZnxDft<FFT64> { impl<D, B> DataViewMut for VecZnxDft<D, B> {
fn data_mut(&self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64; type Scalar = f64;
} }
impl ZnxZero for VecZnxDft<FFT64> {} impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
}
impl ZnxSliceSize for VecZnxDft<FFT64> { pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
fn sl(&self) -> usize { let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
self.n() Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
} }
} }
impl VecZnxDft<FFT64> { pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
}
}
impl<B: Backend> VecZnxDft<B> { // impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
/// Cast a [VecZnxDft] into a [VecZnxBig]. // type Scalar = u8;
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft]. // fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> { // debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
assert!( // Self {
self.data().len() == 0, // inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
"cannot alias VecZnxDft into VecZnxBig if it owns the data" // _marker: PhantomData,
); // }
VecZnxBig::<B> { // }
inner: ZnxBase {
data: Vec::new(), // fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
ptr: self.ptr(), // debug_assert_eq!(
n: self.n(), // _rows, VEC_ZNX_DFT_ROWS,
rows: self.rows(), // "rows != {} not supported for VecZnxDft",
cols: self.cols(), // VEC_ZNX_DFT_ROWS
size: self.size(), // );
}, // unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
_marker: PhantomData, // }
} // }
}
} // impl VecZnxDft<FFT64> {
// pub fn print(&self, n: usize, col: usize) {
// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
// }
// }
// impl<B: Backend> VecZnxDft<B> {
// /// Cast a [VecZnxDft] into a [VecZnxBig].
// /// The returned [VecZnxBig] shares the backing array
// /// with the original [VecZnxDft].
// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
// assert!(
// self.data().len() == 0,
// "cannot alias VecZnxDft into VecZnxBig if it owns the data"
// );
// VecZnxBig::<B> {
// inner: ZnxBase {
// data: Vec::new(),
// ptr: self.ptr(),
// n: self.n(),
// rows: self.rows(),
// cols: self.cols(),
// size: self.size(),
// },
// _marker: PhantomData,
// }
// }
// }

View File

@@ -1,15 +1,14 @@
use crate::VecZnxDftOwned;
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxAlloc; use crate::znx_base::ZnxAlloc;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::znx_base::ZnxLayout; use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement};
use crate::znx_base::ZnxSliceSize;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement};
use std::cmp::min; use std::cmp::min;
pub trait VecZnxDftOps<B: Backend> { pub trait VecZnxDftAlloc<B> {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>; fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
@@ -22,20 +21,20 @@ pub trait VecZnxDftOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<B>; fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. // /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// // ///
/// Behavior: the backing array is only borrowed. // /// Behavior: the backing array is only borrowed.
/// // ///
/// # Arguments // /// # Arguments
/// // ///
/// * `cols`: the number of cols of the [VecZnxDft]. // /// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
/// // ///
/// # Panics // /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. // /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>; // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
@@ -47,37 +46,58 @@ pub trait VecZnxDftOps<B: Backend> {
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftOps<DataMut, Data, B> {
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes]. /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
fn vec_znx_idft_tmp_bytes(&self) -> usize; fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space. /// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &mut VecZnxDft<B>, a_cols: usize); fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &mut VecZnxDft<DataMut, B>, a_cols: usize);
fn vec_znx_idft(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxDft<B>, a_col: usize, tmp_bytes: &mut [u8]); fn vec_znx_idft(
&self,
res: &mut VecZnxBig<DataMut, B>,
res_col: usize,
a: &VecZnxDft<Data, B>,
a_col: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
} }
impl VecZnxDftOps<FFT64> for Module<FFT64> { impl VecZnxDftAlloc<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> { fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
VecZnxDft::<FFT64>::new(&self, 1, cols, size) VecZnxDftOwned::new(&self, cols, size)
} }
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<FFT64> { fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
VecZnxDft::from_bytes(self, 1, cols, size, bytes) VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
} }
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> { // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) // VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes)
} // }
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, 1, cols, size) VecZnxDft::bytes_of(&self, cols, size)
} }
}
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) { impl<DataMut, Data> VecZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
where
DataMut: AsMut<[u8]> + AsRef<[u8]>,
Data: AsRef<[u8]>,
{
fn vec_znx_idft_tmp_a(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &mut VecZnxDft<DataMut, FFT64>,
a_col: usize,
) {
let min_size: usize = min(res.size(), a.size()); let min_size: usize = min(res.size(), a.size());
unsafe { unsafe {
@@ -86,7 +106,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
self.ptr, self.ptr,
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64, 1 as u64,
a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, a.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64, 1 as u64,
) )
}); });
@@ -104,7 +124,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
/// ///
/// # Panics /// # Panics
/// If b.cols < a_cols /// If b.cols < a_cols
fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, FFT64>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
let min_size: usize = min(res.size(), a.size()); let min_size: usize = min(res.size(), a.size());
unsafe { unsafe {
@@ -125,7 +145,14 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
} }
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxDft<FFT64>, a_col: usize, tmp_bytes: &mut [u8]) { fn vec_znx_idft(
&self,
res: &mut VecZnxBig<DataMut, FFT64>,
res_col: usize,
a: &VecZnxDft<Data, FFT64>,
a_col: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!( assert!(

View File

@@ -1,14 +1,15 @@
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; use crate::znx_base::{ZnxInfos, switch_degree};
use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement}; use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement};
pub trait VecZnxOps {
pub trait VecZnxAlloc {
/// Allocates a new [VecZnx]. /// Allocates a new [VecZnx].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `cols`: the number of polynomials. /// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column. /// * `size`: the number small polynomials per column.
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
/// Instantiates a new [VecZnx] from a slice of bytes. /// Instantiates a new [VecZnx] from a slice of bytes.
/// The returned [VecZnx] takes ownership of the slice of bytes. /// The returned [VecZnx] takes ownership of the slice of bytes.
@@ -20,25 +21,28 @@ pub trait VecZnxOps {
/// ///
/// # Panic /// # Panic
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx; fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
/// Instantiates a new [VecZnx] from a slice of bytes. // /// Instantiates a new [VecZnx] from a slice of bytes.
/// The returned [VecZnx] does take ownership of the slice of bytes. // /// The returned [VecZnx] does take ownership of the slice of bytes.
/// // ///
/// # Arguments // /// # Arguments
/// // ///
/// * `cols`: the number of polynomials. // /// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column. // /// * `size`: the number small polynomials per column.
/// // ///
/// # Panic // /// # Panic
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. // /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; // fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx;
// (Jay)TODO
/// Returns the number of bytes necessary to allocate /// Returns the number of bytes necessary to allocate
/// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes]
/// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. /// or [VecZnxOps::new_vec_znx_from_bytes_borrow].
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxOps<DataMut, Data> {
/// Returns the minimum number of bytes necessary for normalization. /// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self) -> usize; fn vec_znx_normalize_tmp_bytes(&self) -> usize;
@@ -46,48 +50,64 @@ pub trait VecZnxOps {
fn vec_znx_normalize( fn vec_znx_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx<DataMut>,
res_col: usize, res_col: usize,
a: &VecZnx, a: &VecZnx<Data>,
a_col: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
/// Normalizes the selected column of `a`. /// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx<DataMut>, a_col: usize, tmp_bytes: &mut [u8]);
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); fn vec_znx_add(
&self,
res: &mut VecZnx<DataMut>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
b: &VecZnx<Data>,
b_col: usize,
);
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_add_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); fn vec_znx_sub(
&self,
res: &mut VecZnx<DataMut>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
b: &VecZnx<Data>,
b_col: usize,
);
/// Subtracts the selected column of `a` to the selected column of `res`. /// Subtracts the selected column of `a` from the selected column of `res` inplace.
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. // /// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`.
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
// Negates the selected column of `a` and stores the result on the selected column of `res`. // Negates the selected column of `a` and stores the result in `res_col` of `res`.
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_negate(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Negates the selected column of `a`. /// Negates the selected column of `a`.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); fn vec_znx_negate_inplace(&self, a: &mut VecZnx<DataMut>, a_col: usize);
/// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Multiplies the selected column of `a` by X^k. /// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize);
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
/// Applies the automorphism X^i -> X^ik on the selected column of `a`. /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize);
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
/// ///
@@ -95,7 +115,14 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of b have the same ring degree /// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n() /// and that b.n() * b.len() <= a.n()
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx); fn vec_znx_split(
&self,
res: &mut Vec<VecZnx<DataMut>>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
buf: &mut VecZnx<DataMut>,
);
/// Merges the subrings of the selected column of `a` into the selected column of `res`. /// Merges the subrings of the selected column of `a` into the selected column of `res`.
/// ///
@@ -103,26 +130,29 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of a have the same ring degree /// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n() /// and that a.n() * a.len() <= b.n()
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize); fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize);
} }
impl<B: Backend> VecZnxOps for Module<B> { impl<B: Backend> VecZnxAlloc for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { //(Jay)TODO: One must define the Scalar generic param here.
VecZnx::new(self, VEC_ZNX_ROWS, cols, size) fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
VecZnxOwned::new(self.n(), cols, size)
} }
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size) VecZnxOwned::bytes_of(self.n(), cols, size)
} }
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx { fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes) VecZnxOwned::new_from_bytes(self.n(), cols, size, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes)
} }
}
impl<B: Backend, DataMut, Data> VecZnxOps<DataMut, Data> for Module<B>
where
Data: AsRef<[u8]>,
DataMut: AsRef<[u8]> + AsMut<[u8]>,
{
fn vec_znx_normalize_tmp_bytes(&self) -> usize { fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
} }
@@ -130,9 +160,9 @@ impl<B: Backend> VecZnxOps for Module<B> {
fn vec_znx_normalize( fn vec_znx_normalize(
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx<DataMut>,
res_col: usize, res_col: usize,
a: &VecZnx, a: &VecZnx<Data>,
a_col: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
@@ -158,7 +188,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx<DataMut>, a_col: usize, tmp_bytes: &mut [u8]) {
unsafe { unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx; let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_normalize( Self::vec_znx_normalize(
@@ -173,7 +203,15 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { fn vec_znx_add(
&self,
res: &mut VecZnx<DataMut>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
b: &VecZnx<Data>,
b_col: usize,
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -197,14 +235,21 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_add_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx; Self::vec_znx_add(&self, res, res_col, a, a_col, res, res_col);
Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { fn vec_znx_sub(
&self,
res: &mut VecZnx<DataMut>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
b: &VecZnx<Data>,
b_col: usize,
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -228,21 +273,21 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx; let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); Self::vec_znx_sub(self, res, res_col, a, a_col, res, res_col);
} }
} }
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe { // unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx; // let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); // Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} // }
} // }
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_negate(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -261,14 +306,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { fn vec_znx_negate_inplace(&self, a: &mut VecZnx<DataMut>, a_col: usize) {
unsafe { unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_negate(self, a, a_col, a, a_col);
Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col);
} }
} }
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -288,14 +332,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize) {
unsafe { unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_rotate(self, k, a, a_col, a, a_col);
Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
} }
} }
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
@@ -315,14 +358,20 @@ impl<B: Backend> VecZnxOps for Module<B> {
} }
} }
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize) {
unsafe { unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_automorphism(self, k, a, a_col, a, a_col);
Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
} }
} }
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) { fn vec_znx_split(
&self,
res: &mut Vec<VecZnx<DataMut>>,
res_col: usize,
a: &VecZnx<Data>,
a_col: usize,
buf: &mut VecZnx<DataMut>,
) {
let (n_in, n_out) = (a.n(), res[0].n()); let (n_in, n_out) = (a.n(), res[0].n());
debug_assert!( debug_assert!(
@@ -348,7 +397,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
}) })
} }
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize) { fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize) {
let (n_in, n_out) = (res.n(), a[0].n()); let (n_in, n_out) = (res.n(), a[0].n());
debug_assert!( debug_assert!(

View File

@@ -54,11 +54,9 @@ pub trait GetZnxBase {
fn znx_mut(&mut self) -> &mut ZnxBase; fn znx_mut(&mut self) -> &mut ZnxBase;
} }
pub trait ZnxInfos: GetZnxBase { pub trait ZnxInfos {
/// Returns the ring degree of the polynomials. /// Returns the ring degree of the polynomials.
fn n(&self) -> usize { fn n(&self) -> usize;
self.znx().n
}
/// Returns the base two logarithm of the ring dimension of the polynomials. /// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize { fn log_n(&self) -> usize {
@@ -66,41 +64,27 @@ pub trait ZnxInfos: GetZnxBase {
} }
/// Returns the number of rows. /// Returns the number of rows.
fn rows(&self) -> usize { fn rows(&self) -> usize;
self.znx().rows
}
/// Returns the number of polynomials in each row. /// Returns the number of polynomials in each row.
fn cols(&self) -> usize { fn cols(&self) -> usize;
self.znx().cols
}
/// Returns the number of size per polynomial. /// Returns the number of size per polynomial.
fn size(&self) -> usize { fn size(&self) -> usize;
self.znx().size
}
/// Returns the underlying raw bytes array.
fn data(&self) -> &[u8] {
&self.znx().data
}
/// Returns a pointer to the underlying raw bytes array.
fn ptr(&self) -> *mut u8 {
self.znx().ptr
}
/// Returns the total number of small polynomials. /// Returns the total number of small polynomials.
fn poly_count(&self) -> usize { fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size() self.rows() * self.cols() * self.size()
} }
}
pub trait ZnxSliceSize {
/// Returns the slice size, which is the offset between /// Returns the slice size, which is the offset between
/// two size of the same column. /// two size of the same column.
fn sl(&self) -> usize; fn sl(&self) -> usize;
} }
// pub trait ZnxSliceSize {}
//(Jay) TODO: Remove ZnxAlloc
pub trait ZnxAlloc<B: Backend> pub trait ZnxAlloc<B: Backend>
where where
Self: Sized + ZnxInfos, Self: Sized + ZnxInfos,
@@ -122,22 +106,21 @@ where
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize; fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
} }
pub trait ZnxLayout: ZnxInfos { pub trait DataView {
type Scalar; type D;
fn data(&self) -> &Self::D;
}
/// Returns true if the receiver is only borrowing the data. pub trait DataViewMut: DataView {
fn borrowing(&self) -> bool { fn data_mut(&self) -> &mut Self::D;
self.znx().data.len() == 0 }
}
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
type Scalar;
/// Returns a non-mutable pointer to the underlying coefficients array. /// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar { fn as_ptr(&self) -> *const Self::Scalar {
self.znx().ptr as *const Self::Scalar self.data().as_ref().as_ptr() as *const Self::Scalar
}
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.znx_mut().ptr as *mut Self::Scalar
} }
/// Returns a non-mutable reference to the entire underlying coefficient array. /// Returns a non-mutable reference to the entire underlying coefficient array.
@@ -145,11 +128,6 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
} }
/// Returns a mutable reference to the entire underlying coefficient array.
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -161,6 +139,23 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { self.as_ptr().add(offset) } unsafe { self.as_ptr().add(offset) }
} }
/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}
}
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
}
/// Returns a mutable reference to the entire underlying coefficient array.
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -172,17 +167,15 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { self.as_mut_ptr().add(offset) } unsafe { self.as_mut_ptr().add(offset) }
} }
/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}
/// Returns mutable reference to the (i, j)-th small polynomial. /// Returns mutable reference to the (i, j)-th small polynomial.
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
} }
} }
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
use std::convert::TryFrom; use std::convert::TryFrom;
use std::num::TryFromIntError; use std::num::TryFromIntError;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
@@ -213,7 +206,7 @@ impl IntegerType for i128 {
const BITS: u32 = 128; const BITS: u32 = 128;
} }
pub trait ZnxZero: ZnxLayout pub trait ZnxZero: ZnxViewMut
where where
Self: Sized, Self: Sized,
{ {
@@ -238,16 +231,16 @@ where
} }
} }
pub trait ZnxRsh: ZnxLayout + ZnxZero pub trait ZnxRsh: ZnxZero {
where
Self: Sized,
Self::Scalar: IntegerType,
{
fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
rsh(k, log_base2k, self, col, carry) rsh(k, log_base2k, self, col, carry)
} }
} }
// Blanket implementations
impl<T> ZnxZero for T where T: ZnxViewMut {}
impl<T> ZnxRsh for T where T: ZnxZero {}
pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
where where
V::Scalar: IntegerType, V::Scalar: IntegerType,
@@ -310,10 +303,7 @@ pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
n * std::mem::size_of::<T>() n * std::mem::size_of::<T>()
} }
pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize) pub fn switch_degree<DMut: ZnxViewMut + ZnxZero, D: ZnxView>(b: &mut DMut, col_b: usize, a: &D, col_a: usize) {
where
<T as ZnxLayout>::Scalar: IntegerType,
{
let (n_in, n_out) = (a.n(), b.n()); let (n_in, n_out) = (a.n(), b.n());
let (gap_in, gap_out): (usize, usize); let (gap_in, gap_out): (usize, usize);
@@ -334,3 +324,64 @@ where
.for_each(|(x_in, x_out)| *x_out = *x_in); .for_each(|(x_in, x_out)| *x_out = *x_in);
}); });
} }
// pub trait ZnxLayout: ZnxInfos {
// type Scalar;
// /// Returns true if the receiver is only borrowing the data.
// fn borrowing(&self) -> bool {
// self.znx().data.len() == 0
// }
// /// Returns a non-mutable pointer to the underlying coefficients array.
// fn as_ptr(&self) -> *const Self::Scalar {
// self.znx().ptr as *const Self::Scalar
// }
// /// Returns a mutable pointer to the underlying coefficients array.
// fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
// self.znx_mut().ptr as *mut Self::Scalar
// }
// /// Returns a non-mutable reference to the entire underlying coefficient array.
// fn raw(&self) -> &[Self::Scalar] {
// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
// }
// /// Returns a mutable reference to the entire underlying coefficient array.
// fn raw_mut(&mut self) -> &mut [Self::Scalar] {
// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
// }
// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
// #[cfg(debug_assertions)]
// {
// assert!(i < self.cols());
// assert!(j < self.size());
// }
// let offset: usize = self.n() * (j * self.cols() + i);
// unsafe { self.as_ptr().add(offset) }
// }
// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
// #[cfg(debug_assertions)]
// {
// assert!(i < self.cols());
// assert!(j < self.size());
// }
// let offset: usize = self.n() * (j * self.cols() + i);
// unsafe { self.as_mut_ptr().add(offset) }
// }
// /// Returns non-mutable reference to the (i, j)-th small polynomial.
// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
// }
// /// Returns mutable reference to the (i, j)-th small polynomial.
// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
// }
// }