Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Commit message:
wip
@@ -18,10 +18,17 @@ pub mod vec_znx_dft_ops;
 pub mod vec_znx_ops;
 pub mod znx_base;
 
+use std::{
+    any::type_name,
+    ops::{DerefMut, Sub},
+};
+
 pub use encoding::*;
 pub use mat_znx_dft::*;
 pub use mat_znx_dft_ops::*;
 pub use module::*;
+use rand_core::le;
+use rand_distr::num_traits::sign;
 pub use sampling::*;
 pub use scalar_znx::*;
 pub use scalar_znx_dft::*;
@@ -126,28 +133,177 @@ pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
     )
 }
 
-pub(crate) struct ScratchSpace {
-    // data: D,
-}
+pub struct ScratchOwned(Vec<u8>);
 
-impl ScratchSpace {
-    fn tmp_vec_znx_dft<D, B>(&mut self, n: usize, cols: usize, size: usize) -> VecZnxDft<D, B> {
-        todo!()
+impl ScratchOwned {
+    pub fn new(byte_count: usize) -> Self {
+        let data: Vec<u8> = alloc_aligned(byte_count);
+        Self(data)
     }
 
-    fn tmp_vec_znx_big<D, B>(&mut self, n: usize, cols: usize, size: usize) -> VecZnxBig<D, B> {
-        todo!()
-    }
-
-    fn vec_znx_big_normalize_tmp_bytes<B: Backend>(&mut self, module: &Module<B>) -> &mut [u8] {
-        todo!()
-    }
-
-    fn vmp_apply_dft_tmp_bytes<B: Backend>(&mut self, module: &Module<B>) -> &mut [u8] {
-        todo!()
-    }
-
-    fn vmp_apply_dft_to_dft_tmp_bytes<B: Backend>(&mut self, module: &Module<B>) -> &mut [u8] {
-        todo!()
+    pub fn borrow(&mut self) -> &mut ScratchBorr {
+        ScratchBorr::new(&mut self.0)
     }
 }
 
+pub struct ScratchBorr {
+    data: [u8],
+}
+
+impl ScratchBorr {
+    fn new(data: &mut [u8]) -> &mut Self {
+        unsafe { &mut *(data as *mut [u8] as *mut Self) }
+    }
+
+    fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
+        let ptr = data.as_mut_ptr();
+        let self_len = data.len();
+
+        let aligned_offset = ptr.align_offset(DEFAULTALIGN);
+        let aligned_len = self_len.saturating_sub(aligned_offset);
+
+        if let Some(rem_len) = aligned_len.checked_sub(take_len) {
+            unsafe {
+                let rem_ptr = ptr.add(aligned_offset).add(take_len);
+                let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
+
+                let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
+
+                return (take_slice, rem_slice);
+            }
+        } else {
+            panic!(
+                "Attempted to take {} from scratch with {} aligned bytes left",
+                take_len,
+                aligned_len,
+                // type_name::<T>(),
+                // aligned_len
+            );
+        }
+    }
+
+    fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
+
+        unsafe {
+            (
+                &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
+                Self::new(rem_slice),
+            )
+        }
+    }
+
+    fn tmp_vec_znx_dft<B: Backend>(
+        &mut self,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
+
+        (
+            VecZnxDft::from_data(take_slice, module.n(), cols, size),
+            Self::new(rem_slice),
+        )
+    }
+
+    fn tmp_vec_znx_big<D: for<'a> From<&'a mut [u8]>, B: Backend>(
+        &mut self,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (VecZnxBig<D, B>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
+
+        (
+            VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size),
+            Self::new(rem_slice),
+        )
+    }
+}
+
+// pub struct ScratchBorrowed<'a> {
+//     data: &'a mut [u8],
+// }
+
+// impl<'a> ScratchBorrowed<'a> {
+//     fn take_slice<T>(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) {
+//         let ptr = self.data.as_mut_ptr();
+//         let self_len = self.data.len();
+
+//         //TODO(Jay): print the offset sometimes, just to check
+//         let aligned_offset = ptr.align_offset(DEFAULTALIGN);
+//         let aligned_len = self_len.saturating_sub(aligned_offset);
+
+//         let take_len_bytes = take_len * std::mem::size_of::<T>();
+
+//         if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) {
+//             unsafe {
+//                 let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes);
+//                 let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
+
+//                 let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes);
+
+//                 return (take_slice, ScratchBorrowed { data: rem_slice });
+//             }
+//         } else {
+//             panic!(
+//                 "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left",
+//                 take_len_bytes,
+//                 take_len,
+//                 type_name::<T>(),
+//                 aligned_len
+//             );
+//         }
+//     }
+
+//     fn reborrow(&mut self) -> ScratchBorrowed<'a> {
+//         //(Jay)TODO: `data: &mut *self.data` does not work because the lifetime of &mut self is different from 'a.
+//         // But it feels that there should be a simpler impl. than the one below
+//         Self {
+//             data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) },
+//         }
+//     }
+
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) {
+//         let (data, re_scratch) = self.take_slice::<u8>(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
+//         (
+//             VecZnxDft::from_data(data, module.n(), cols, size),
+//             re_scratch,
+//         )
+//     }
+
+//     pub(crate) fn len(&self) -> usize {
+//         self.data.len()
+//     }
+// }
+
+// pub trait Scratch<D> {
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (D, &mut Self);
+// }
+
+// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr {
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) {
+//         let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
+//         (
+//             data
+//             rem_scratch,
+//         )
+//     }
+
+//     // fn tmp_vec_znx_big<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) {
+//     //     // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size));
+//     //     // (
+//     //     //     VecZnxBig::from_data(data, module.n(), cols, size),
+//     //     //     re_scratch,
+//     //     // )
+//     // }
+
+//     // fn scalar_slice<T>(&mut self, len: usize) -> (&mut [T], Self) {
+//     //     self.take_slice::<T>(len)
+//     // }
+
+//     // fn reborrow(&mut self) -> Self {
+//     //     self.reborrow()
+//     // }
+// }
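
An aside on the `ScratchBorr` trick added above: because the struct's only field is an unsized `[u8]`, a `&mut [u8]` can be reinterpreted as `&mut ScratchBorr`, and each carve hands back a fresh wrapper over the remainder, which avoids threading an explicit lifetime parameter through every `tmp_*` helper (the problem the commented-out `ScratchBorrowed<'a>` ran into). A minimal stand-alone sketch of the same pattern; the names are illustrative, and `#[repr(transparent)]` is added here to make the layout assumption behind the pointer cast explicit:

```rust
#[repr(transparent)] // same layout as [u8], which makes the cast in `new` defensible
struct Scratch([u8]);

impl Scratch {
    fn new(data: &mut [u8]) -> &mut Self {
        // A one-field unsized newtype can wrap a slice in place.
        unsafe { &mut *(data as *mut [u8] as *mut Self) }
    }

    // Split off `n` bytes; the caller gets the taken slice plus a new
    // `&mut Scratch` over what is left, so nested takes cannot overlap.
    fn take(&mut self, n: usize) -> (&mut [u8], &mut Self) {
        let (head, tail) = self.0.split_at_mut(n);
        (head, Self::new(tail))
    }
}

fn main() {
    let mut buf = vec![0u8; 16];
    let scratch = Scratch::new(&mut buf);
    let (head, rest) = scratch.take(4);
    head.fill(7);
    assert_eq!(rest.0.len(), 12);
    assert_eq!(&buf[..4], &[7, 7, 7, 7]);
}
```

The crate's version additionally skips to the next `DEFAULTALIGN` boundary before taking, so every handed-out temporary starts at an aligned address.
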
@@ -1,4 +1,4 @@
-use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos};
+use crate::znx_base::ZnxInfos;
 use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
 use std::marker::PhantomData;
 
@@ -111,26 +111,6 @@ impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
             _marker: PhantomData,
         }
     }
-
-    // pub fn from_bytes_borrow(
-    //     module: &Module<B>,
-    //     rows: usize,
-    //     cols_in: usize,
-    //     cols_out: usize,
-    //     size: usize,
-    //     bytes: &mut [u8],
-    // ) -> Self {
-    //     debug_assert_eq!(
-    //         bytes.len(),
-    //         Self::bytes_of(module, rows, cols_in, cols_out, size)
-    //     );
-    //     Self {
-    //         inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
-    //         cols_in: cols_in,
-    //         cols_out: cols_out,
-    //         _marker: PhantomData,
-    //     }
-    // }
 }
 
 impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
@@ -170,3 +150,29 @@ impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
 }
 
 pub type MatZnxDftAllocOwned<B> = MatZnxDft<Vec<u8>, B>;
+
+impl<B> MatZnxDft<Vec<u8>, B> {
+    pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+        MatZnxDft {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            size: self.size,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            _marker: PhantomData,
+        }
+    }
+
+    pub fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data.as_slice(),
+            n: self.n,
+            size: self.size,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            _marker: PhantomData,
+        }
+    }
+}
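
The `to_mut`/`to_ref` pair added above is the standard move for containers that are generic over their data holder: an owning `MatZnxDft<Vec<u8>, B>` reborrows as a cheap slice-backed view, so one owned buffer can feed APIs written against `&mut [u8]`- or `&[u8]`-backed instances. A minimal sketch of the pattern, with `Buf` as an illustrative stand-in for the crate's data-generic containers:

```rust
// `Buf` stands in for MatZnxDft / VecZnx / VecZnxBig, which carry more fields.
struct Buf<D> {
    data: D,
    n: usize,
}

impl Buf<Vec<u8>> {
    // Reborrow the owned bytes as a mutable view; no copy, no allocation.
    fn to_mut(&mut self) -> Buf<&mut [u8]> {
        Buf { data: self.data.as_mut_slice(), n: self.n }
    }

    // Same for a shared view.
    fn to_ref(&self) -> Buf<&[u8]> {
        Buf { data: self.data.as_slice(), n: self.n }
    }
}

fn main() {
    let mut owned = Buf { data: vec![0u8; 8], n: 4 };
    let view = owned.to_mut();
    view.data[0] = 1; // writes through to the owned buffer
    assert_eq!(owned.data[0], 1);
}
```
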
@@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
-    VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned,
+    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
+    VecZnxDftAlloc, VecZnxDftOps,
 };
 
 pub trait MatZnxDftAlloc<B> {
@@ -36,12 +36,55 @@ pub trait MatZnxDftAlloc<B> {
     // ) -> MatZnxDft<FFT64>;
 }
 
-/// This trait implements methods for vector matrix product,
-/// that is, multiplying a [VecZnx] with a [MatZnxDft].
-pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
+pub trait MatZnxDftScratch {
     /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row]
     fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
 
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
+
+    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_size`: size of the output [VecZnxDft].
+    /// * `a_size`: size of the input [VecZnx].
+    /// * `rows`: number of rows of the input [MatZnxDft].
+    /// * `size`: size of the input [MatZnxDft].
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+
+    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_size`: size of the output [VecZnxDft].
+    /// * `a_size`: size of the input [VecZnxDft].
+    /// * `rows`: number of rows of the input [MatZnxDft].
+    /// * `size`: size of the input [MatZnxDft].
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+/// This trait implements methods for vector matrix product,
+/// that is, multiplying a [VecZnx] with a [MatZnxDft].
+pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
     ///
     /// # Arguments
@@ -58,12 +101,9 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         b_row: usize,
         b_col_in: usize,
         a: &VecZnx<Data>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
-    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
-    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
-
     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
     ///
     /// # Arguments
@@ -78,7 +118,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         a: &MatZnxDft<Data, B>,
         b_row: usize,
         b_col_in: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
@@ -101,24 +141,6 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `row_i`: the index of the row to extract.
     fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, B>, a: &MatZnxDft<Data, B>, a_row: usize, a_col_in: usize);
 
-    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: size of the output [VecZnxDft].
-    /// * `a_size`: size of the input [VecZnx].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: size of the input [MatZnxDft].
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        c_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize;
-
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     ///
     /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
@@ -143,27 +165,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `a`: the left operand [VecZnx] of the vector matrix product.
     /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-    fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchSpace);
+    fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchBorr);
 
-    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: size of the output [VecZnxDft].
-    /// * `a_size`: size of the input [VecZnxDft].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: size of the input [MatZnxDft].
-    fn vmp_apply_dft_to_dft_tmp_bytes(
-        &self,
-        c_cols: usize,
-        c_size: usize,
-        a_cols: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize;
-
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -195,7 +197,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         c: &mut VecZnxDft<DataMut, B>,
         a: &VecZnxDft<Data, B>,
         b: &MatZnxDft<Data, B>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 }
 
@@ -220,22 +222,70 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
     }
 }
 
-impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-    DataMut: AsMut<[u8]> + AsRef<[u8]>,
-    Data: AsRef<[u8]>,
-{
+impl<B: Backend> MatZnxDftScratch for Module<B> {
     fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
-        <Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
+        <Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
     }
 
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        <Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
+            + <Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(self)
+    }
+
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
+        unsafe {
+            vmp::vmp_apply_dft_tmp_bytes(
+                self.ptr,
+                c_size as u64,
+                a_size as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
+            ) as usize
+        }
+    }
+
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
+        unsafe {
+            vmp::vmp_apply_dft_to_dft_tmp_bytes(
+                self.ptr,
+                (c_size * c_cols) as u64,
+                (a_size * a_cols) as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
+            ) as usize
+        }
+    }
+}
+
+impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
+where
+    DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>,
+    Data: AsRef<[u8]>,
+{
     fn vmp_prepare_row(
         &self,
         b: &mut MatZnxDft<DataMut, FFT64>,
         b_row: usize,
         b_col_in: usize,
         a: &VecZnx<Data>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -278,17 +328,13 @@ where
         let a_size: usize = a.size();
 
         // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
-        let mut a_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, a_size);
+        let (mut a_dft, _) = scratch.tmp_scalar_slice(12);
+        DataMut::from(a_dft);
+        // let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<DataMut, _>(self, cols_out, a_size);
         (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
 
         Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
     }
 
-    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
-        self.bytes_of_vec_znx_dft(cols_out, size)
-            + <Self as VecZnxBigOps<DataMut, Data, FFT64>>::vec_znx_big_normalize_tmp_bytes(self)
-    }
-
     fn vmp_extract_row(
         &self,
         log_base2k: usize,
@@ -296,7 +342,7 @@ where
         a: &MatZnxDft<Data, FFT64>,
         a_row: usize,
         a_col_in: usize,
-        scratch: &mut ScratchSpace,
+        mut scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -336,9 +382,9 @@ where
         let size: usize = b.size();
 
         // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
-        let mut b_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, size);
+        let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size);
         Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
-        let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size);
+        let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
         (0..cols_out).for_each(|i| {
             <Self as VecZnxDftOps<DataMut, Data, FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
             self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
@@ -434,32 +480,12 @@ where
         }
     }
 
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        res_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize {
-        unsafe {
-            vmp::vmp_apply_dft_tmp_bytes(
-                self.ptr,
-                res_size as u64,
-                a_size as u64,
-                (b_rows * b_cols_in) as u64,
-                (b_size * b_cols_out) as u64,
-            ) as usize
-        }
-    }
-
     fn vmp_apply_dft(
         &self,
         c: &mut VecZnxDft<DataMut, FFT64>,
         a: &VecZnx<Data>,
         b: &MatZnxDft<Data, FFT64>,
-        scratch: &mut ScratchSpace,
+        mut scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -493,6 +519,16 @@ where
             // );
            // assert_alignement(tmp_bytes.as_ptr());
         }
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as MatZnxDftScratch>::vmp_apply_dft_tmp_bytes(
+            self,
+            c.size(),
+            a.size(),
+            b.rows(),
+            b.cols_in(),
+            b.cols_out(),
+            b.size(),
+        ));
+
         unsafe {
             vmp::vmp_apply_dft(
                 self.ptr,
@@ -504,39 +540,17 @@ where
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 (b.rows() * b.cols_in()) as u64,
                 (b.size() * b.cols_out()) as u64,
-                scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
-    fn vmp_apply_dft_to_dft_tmp_bytes(
-        &self,
-        res_cols: usize,
-        res_size: usize,
-        a_size: usize,
-        a_cols: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize {
-        unsafe {
-            vmp::vmp_apply_dft_to_dft_tmp_bytes(
-                self.ptr,
-                (res_size * res_cols) as u64,
-                (a_size * a_cols) as u64,
-                (b_rows * b_cols_in) as u64,
-                (b_size * b_cols_out) as u64,
-            ) as usize
-        }
-    }
-
     fn vmp_apply_dft_to_dft(
         &self,
         c: &mut VecZnxDft<DataMut, FFT64>,
         a: &VecZnxDft<Data, FFT64>,
         b: &MatZnxDft<Data, FFT64>,
-        scratch: &mut ScratchSpace,
+        mut scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -572,6 +586,17 @@ where
             // );
            // assert_alignement(tmp_bytes.as_ptr());
         }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes(
+            c.cols(),
+            c.size(),
+            a.cols(),
+            a.size(),
+            b.rows(),
+            b.cols_in(),
+            b.cols_out(),
+            b.size(),
+        ));
         unsafe {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
@@ -582,7 +607,7 @@ where
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 b.rows() as u64,
                 (b.size() * b.cols()) as u64,
-                scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             )
         }
     }
@@ -590,6 +615,7 @@ where
 
 #[cfg(test)]
 mod tests {
+    use crate::ScratchOwned;
     use crate::mat_znx_dft_ops::*;
     use crate::vec_znx_big_ops::*;
     use crate::vec_znx_dft_ops::*;
@@ -617,7 +643,9 @@ mod tests {
 
         // let mut tmp_bytes: Vec<u8> =
         //     alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
-        let mut scratch = ScratchSpace {};
+        let mut scratch = ScratchOwned::new(
+            2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()),
+        );
         let mut tmp_bytes: Vec<u8> =
             alloc_aligned::<u8>(<Module<FFT64> as VecZnxDftOps<Vec<u8>, Vec<u8>, _>>::vec_znx_idft_tmp_bytes(&module));
 
@@ -630,7 +658,9 @@ mod tests {
             module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
         });
 
-        module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch);
+        // let g = vmpmat_0.to_mut();
+
+        module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow());
 
         // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
         module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
@@ -641,11 +671,25 @@ mod tests {
         assert_eq!(a_dft.raw(), b_dft.raw());
 
         // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-        module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch);
+        // module.vmp_extract_row(
+        //     log_base2k,
+        //     &mut b.to_mut(),
+        //     &vmpmat_0.to_ref(),
+        //     row_i,
+        //     col_in,
+        //     scratch.borrow(),
+        // );
+
         (0..mat_cols_out).for_each(|col_out| {
             module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
-            module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch);
+            module.vec_znx_big_normalize(
+                log_base2k,
+                &mut a.to_mut(),
+                col_out,
+                &a_big.to_ref(),
+                col_out,
+                scratch.borrow(),
+            );
         });
 
         assert_eq!(a.raw(), b.raw());
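
In the updated test, scratch is sized once with `ScratchOwned::new(...)` and every operation takes a fresh `scratch.borrow()`, so the same backing buffer is reused across `vmp_prepare_row` and each `vec_znx_big_normalize` call. A generic, runnable sketch of that allocate-once, reborrow-per-call flow (the types here are stand-ins, not the crate's API):

```rust
// Owner of one reusable scratch allocation; `borrow` hands out a fresh
// exclusive view per call, mirroring ScratchOwned::borrow() above.
struct Owned(Vec<u8>);

impl Owned {
    fn new(bytes: usize) -> Self {
        Owned(vec![0u8; bytes])
    }
    fn borrow(&mut self) -> &mut [u8] {
        &mut self.0
    }
}

// Two operations with different scratch appetites.
fn op_a(scratch: &mut [u8]) { scratch[..4].fill(1); }
fn op_b(scratch: &mut [u8]) { scratch[..8].fill(2); }

fn main() {
    // Size once for the largest consumer (the test sums the *_tmp_bytes
    // queries), then reuse the buffer for every call:
    let mut scratch = Owned::new(64);
    op_a(scratch.borrow());
    op_b(scratch.borrow()); // each call reuses the same backing buffer
    assert_eq!(scratch.0[0], 2);
}
```
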
@@ -97,11 +97,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
     pub fn switch_degree<Data: AsRef<[u8]>>(&mut self, col: usize, a: &VecZnx<Data>, col_a: usize) {
         switch_degree(self, col_a, a, col)
     }
-
-    // Prints the first `n` coefficients of each limb
-    // pub fn print(&self, n: usize, col: usize) {
-    //     (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
-    // }
 }
 
 impl<D: From<Vec<u8>>> VecZnx<D> {
@@ -131,8 +126,6 @@ impl<D: From<Vec<u8>>> VecZnx<D> {
     }
 }
 
-//(Jay)TODO: Impl. truncate pow2 for Owned Vector
-
 /// Copies the coefficients of `a` on the receiver.
 /// Copy is done with the minimum size matching both backing arrays.
 /// Panics if the cols do not match.
@@ -148,12 +141,6 @@ where
         data_b[..size].copy_from_slice(&data_a[..size])
     }
 
-    // if !self.borrowing() {
-    //     self.inner
-    //         .data
-    //         .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
-    // }
-
 fn normalize_tmp_bytes(n: usize) -> usize {
     n * std::mem::size_of::<i64>()
 }
@@ -190,26 +177,6 @@ fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>,
     }
 }
 
-// impl<B: Backend> ZnxAlloc<B> for VecZnx {
-//     type Scalar = i64;
-
-//     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
-//         debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
-//         VecZnx {
-//             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
-//         }
-//     }
-
-//     fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
-//         debug_assert_eq!(
-//             _rows, VEC_ZNX_ROWS,
-//             "rows != {} not supported for VecZnx",
-//             VEC_ZNX_ROWS
-//         );
-//         module.n() * cols * size * size_of::<Self::Scalar>()
-//     }
-// }
-
 impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
@@ -248,3 +215,23 @@ impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
 pub type VecZnxOwned = VecZnx<Vec<u8>>;
 pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
 pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
+
+impl VecZnx<Vec<u8>> {
+    pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        VecZnx {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+
+    pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+}
@@ -53,13 +53,13 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
     type Scalar = i64;
 }
 
-impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
-    pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
-        unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
-    }
+pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
+    unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
+}
 
+impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
     pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
-        let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
+        let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -71,7 +71,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
 
     pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(module, cols, size));
+        assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -82,8 +82,42 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
         }
     }
 }
+
+impl<D, B> VecZnxBig<D, B> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
+        Self {
+            data,
+            n,
+            cols,
+            size,
+            _phantom: PhantomData,
+        }
+    }
+}
 
 pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
+
+impl<B> VecZnxBig<Vec<u8>, B> {
+    pub(crate) fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
+        VecZnxBig {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+        VecZnxBig {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
 
 // impl VecZnxBig<FFT64> {
 //     pub fn print(&self, n: usize, col: usize) {
 //         (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
@@ -1,6 +1,9 @@
 use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
-use crate::{Backend, DataView, FFT64, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement};
+use crate::{
+    Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement,
+    bytes_of_vec_znx_big,
+};
 
 pub trait VecZnxBigAlloc<B> {
     /// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
@@ -113,9 +116,6 @@ pub trait VecZnxBigOps<DataMut, Data, B> {
     /// Subtracts `res` from `a` and stores the result on `res`.
     fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
 
-    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
-
     /// Normalizes `a` and stores the result on `b`.
     ///
     /// # Arguments
@@ -129,7 +129,7 @@ pub trait VecZnxBigOps<DataMut, Data, B> {
         res_col: usize,
         a: &VecZnxBig<Data, B>,
         a_col: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
@@ -146,6 +146,11 @@ pub trait VecZnxBigOps<DataMut, Data, B> {
     fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, B>, a_col: usize);
 }
 
+pub trait VecZnxBigScratch {
+    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
+}
+
 impl VecZnxBigAlloc<FFT64> for Module<FFT64> {
     fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
         VecZnxBig::new(self, cols, size)
@@ -160,7 +165,7 @@ impl VecZnxBigAlloc<FFT64> for Module<FFT64> {
     // }
 
     fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
-        VecZnxBigOwned::bytes_of(self, cols, size)
+        bytes_of_vec_znx_big(self, cols, size)
     }
 }
 
@@ -491,10 +496,6 @@ where
         }
     }
 
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
-        <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(self)
-    }
-
     fn vec_znx_big_normalize(
         &self,
         log_base2k: usize,
@@ -502,7 +503,7 @@ where
         res_col: usize,
         a: &VecZnxBig<Data, FFT64>,
         a_col: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -513,6 +514,10 @@ where
             // assert!(tmp_bytes.len() >= <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(&self));
             // assert_alignement(tmp_bytes.as_ptr());
         }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
+            &self,
+        ));
         unsafe {
             vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
@@ -523,7 +528,7 @@ where
                 a.at_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
-                scratch.vec_znx_big_normalize_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             );
         }
     }
@@ -574,3 +579,9 @@ where
         }
     }
 }
+
+impl<B: Backend> VecZnxBigScratch for Module<B> {
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+        <Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
+    }
+}
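
Moving `vec_znx_big_normalize_tmp_bytes` out of the data-generic `VecZnxBigOps<DataMut, Data, B>` and into the new `VecZnxBigScratch` trait means callers can ask for scratch sizes without naming the data generics, and the "big" query simply delegates to the plain one. A runnable sketch of the shape of that split, using the `n * size_of::<i64>()` cost from `normalize_tmp_bytes` earlier in the diff as the example figure (type names here are illustrative stand-ins):

```rust
// Parameter-free scratch-size queries, split off from the data-generic ops
// traits so call sites do not need turbofish like
// `<Module as Ops<Vec<u8>, Vec<u8>>>::...` just to size a buffer.
trait VecScratch {
    fn normalize_tmp_bytes(&self) -> usize;
}

trait VecBigScratch {
    fn big_normalize_tmp_bytes(&self) -> usize;
}

struct Module {
    n: usize, // ring degree
}

impl VecScratch for Module {
    fn normalize_tmp_bytes(&self) -> usize {
        // One i64 carry per coefficient, as in normalize_tmp_bytes above.
        self.n * std::mem::size_of::<i64>()
    }
}

// The big-coefficient variant just delegates, mirroring
// `VecZnxBigScratch -> VecZnxScratch` in the diff.
impl VecBigScratch for Module {
    fn big_normalize_tmp_bytes(&self) -> usize {
        <Self as VecScratch>::normalize_tmp_bytes(self)
    }
}

fn main() {
    let m = Module { n: 1024 };
    assert_eq!(m.big_normalize_tmp_bytes(), 8192);
}
```
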
@@ -54,13 +54,13 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
     type Scalar = f64;
 }
 
-impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
-    pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
-        unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
-    }
+pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
+    unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
+}
 
+impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
     pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
-        let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
+        let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -72,7 +72,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
 
     pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(module, cols, size));
+        assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -85,8 +85,8 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
 
 pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
 
-impl<'a, D: ?Sized, B> VecZnxDft<&'a mut D, B> {
-    pub(crate) fn from_mut_slice(data: &'a mut D, n: usize, cols: usize, size: usize) -> Self {
+impl<D, B> VecZnxDft<D, B> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
         Self {
             data,
             n,
@@ -1,6 +1,7 @@
-use crate::VecZnxDftOwned;
 use crate::ffi::{vec_znx_big, vec_znx_dft};
+use crate::vec_znx_dft::bytes_of_vec_znx_dft;
 use crate::znx_base::ZnxInfos;
+use crate::{Backend, VecZnxDftOwned};
 use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement};
 use std::cmp::min;
 
@@ -66,12 +67,12 @@ pub trait VecZnxDftOps<DataMut, Data, B> {
     fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
 }
 
-impl VecZnxDftAlloc<FFT64> for Module<FFT64> {
-    fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
+impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
+    fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
         VecZnxDftOwned::new(&self, cols, size)
     }
 
-    fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
+    fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
         VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
     }
 
@@ -80,7 +81,7 @@ impl VecZnxDftAlloc<FFT64> for Module<FFT64> {
     // }
 
     fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
-        VecZnxDftOwned::bytes_of(&self, cols, size)
+        bytes_of_vec_znx_dft(self, cols, size)
     }
 }
 
@@ -43,9 +43,6 @@ pub trait VecZnxAlloc {
 }
 
 pub trait VecZnxOps<DataMut, Data> {
-    /// Returns the minimum number of bytes necessary for normalization.
-    fn vec_znx_normalize_tmp_bytes(&self) -> usize;
-
     /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
     fn vec_znx_normalize(
         &self,
@@ -137,6 +134,11 @@ pub trait VecZnxOps<DataMut, Data> {
     fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize);
 }
 
+pub trait VecZnxScratch {
+    /// Returns the minimum number of bytes necessary for normalization.
+    fn vec_znx_normalize_tmp_bytes(&self) -> usize;
+}
+
 impl<B: Backend> VecZnxAlloc for Module<B> {
     //(Jay)TODO: One must define the Scalar generic param here.
     fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
@@ -157,10 +159,6 @@ where
     Data: AsRef<[u8]>,
     DataMut: AsRef<[u8]> + AsMut<[u8]>,
 {
-    fn vec_znx_normalize_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
-    }
-
     fn vec_znx_normalize(
         &self,
         log_base2k: usize,
@@ -174,7 +172,7 @@ where
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(res.n(), self.n());
-            assert!(tmp_bytes.len() >= <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(&self));
+            assert!(tmp_bytes.len() >= <Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(&self));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -489,3 +487,9 @@ where
         <Self as VecZnxOps<DataMut, Data>>::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col);
     }
 }
+
+impl<B: Backend> VecZnxScratch for Module<B> {
+    fn vec_znx_normalize_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
+    }
+}
@@ -1,59 +1,6 @@
-use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
 use itertools::izip;
 use std::cmp::min;
 
-pub struct ZnxBase {
-    /// The ring degree
-    pub n: usize,
-
-    /// The number of rows (in the third dimension)
-    pub rows: usize,
-
-    /// The number of polynomials
-    pub cols: usize,
-
-    /// The number of small polynomials per polynomial (the size).
-    pub size: usize,
-
-    /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
-    pub data: Vec<u8>,
-
-    /// Pointer to data (data can be empty if [VecZnx] borrows space instead of owning it).
-    pub ptr: *mut u8,
-}
-
-impl ZnxBase {
-    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
-        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
-        res.data = bytes;
-        res
-    }
-
-    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(n & (n - 1), 0, "n must be a power of two");
-            assert!(n > 0, "n must be greater than 0");
-            assert!(rows > 0, "rows must be greater than 0");
-            assert!(cols > 0, "cols must be greater than 0");
-            assert!(size > 0, "size must be greater than 0");
-        }
-        Self {
-            n: n,
-            rows: rows,
-            cols: cols,
-            size: size,
-            data: Vec::new(),
-            ptr: bytes.as_mut_ptr(),
-        }
-    }
-}
-
-pub trait GetZnxBase {
-    fn znx(&self) -> &ZnxBase;
-    fn znx_mut(&mut self) -> &mut ZnxBase;
-}
-
 pub trait ZnxInfos {
     /// Returns the ring degree of the polynomials.
     fn n(&self) -> usize;
@@ -82,30 +29,6 @@ pub trait ZnxInfos {
|
|||||||
fn sl(&self) -> usize;
|
fn sl(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub trait ZnxSliceSize {}
|
|
||||||
|
|
||||||
//(Jay) TODO: Remove ZnxAlloc
|
|
||||||
// pub trait ZnxAlloc<B: Backend>
|
|
||||||
// where
|
|
||||||
// Self: Sized + ZnxInfos,
|
|
||||||
// {
|
|
||||||
// type Scalar;
|
|
||||||
// fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
|
|
||||||
// let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
|
|
||||||
// Self::from_bytes(module, rows, cols, size, bytes)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
|
||||||
// let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
|
|
||||||
// res.znx_mut().data = bytes;
|
|
||||||
// res
|
|
||||||
// }
|
|
||||||
|
|
||||||
// fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
|
||||||
|
|
||||||
// fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub trait DataView {
|
pub trait DataView {
|
||||||
type D;
|
type D;
|
||||||
fn data(&self) -> &Self::D;
|
fn data(&self) -> &Self::D;
|
||||||
@@ -176,35 +99,6 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
|||||||
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
||||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||||
|
|
||||||
-use std::convert::TryFrom;
-use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
-pub trait Num:
-    Copy
-    + Default
-    + PartialEq
-    + PartialOrd
-    + Add<Output = Self>
-    + Sub<Output = Self>
-    + Mul<Output = Self>
-    + Div<Output = Self>
-    + Neg<Output = Self>
-    + AddAssign
-{
-    const BITS: u32;
-}
-
-impl Num for i64 {
-    const BITS: u32 = 64;
-}
-
-impl Num for i128 {
-    const BITS: u32 = 128;
-}
-
-impl Num for f64 {
-    const BITS: u32 = 64;
-}

 pub trait ZnxZero: ZnxViewMut
 where
     Self: Sized,
@@ -261,128 +155,96 @@ pub fn switch_degree<S: Copy, DMut: ZnxViewMut<Scalar = S> + ZnxZero, D: ZnxView
     });
 }

+use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
+
+use crate::{ScratchBorr, cast_mut};
+
+/// Integer-like scalar for limb arithmetic. Compared with the removed `Num`,
+/// this adds the `Shl`/`Shr` bounds needed for the sign-extending carry
+/// extraction in `rsh`, and drops the `f64` impl.
+pub trait Integer:
+    Copy
+    + Default
+    + PartialEq
+    + PartialOrd
+    + Add<Output = Self>
+    + Sub<Output = Self>
+    + Mul<Output = Self>
+    + Div<Output = Self>
+    + Neg<Output = Self>
+    + Shl<Output = Self>
+    + Shr<Output = Self>
+    + AddAssign
+{
+    const BITS: u32;
+}
+
+impl Integer for i64 {
+    const BITS: u32 = 64;
+}
+
+impl Integer for i128 {
+    const BITS: u32 = 128;
+}

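The extra `Shl`/`Shr` bounds exist for the sign-extending carry extraction `(x << shift) >> shift` used by `rsh` below. A standalone demonstration on plain `i64` (toy code, not part of the crate):

// Arithmetic right shift sign-extends, so `(x << shift) >> shift` yields the
// low `k_rem` bits of `x` as a signed value in [-2^(k_rem-1), 2^(k_rem-1)).
fn signed_low_bits(x: i64, k_rem: u32) -> i64 {
    let shift = i64::BITS - k_rem; // matches `V::Scalar::BITS as usize - k_rem`
    (x << shift) >> shift
}

fn main() {
    assert_eq!(signed_low_bits(0b1011, 3), 3); // low bits 011 -> +3
    assert_eq!(signed_low_bits(0b1111, 3), -1); // low bits 111 -> -1
}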
 // (Jay)TODO: implement rsh for VecZnx, VecZnxBig
 // pub trait ZnxRsh: ZnxZero {
 //     fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
 //         rsh(k, log_base2k, self, col, carry)
 //     }
 // }
-// pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) {
-//     let n: usize = a.n();
-//     let size: usize = a.size();
-//     let cols: usize = a.cols();
+// NOTE: `_a_col` is currently unused: the shift below is applied to every column.
+pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut ScratchBorr)
+where
+    V::Scalar: TryFrom<usize> + Integer,
+    <V::Scalar as TryFrom<usize>>::Error: std::fmt::Debug,
+{
+    let n: usize = a.n();
+    let size: usize = a.size();
+    let cols: usize = a.cols();

 //     #[cfg(debug_assertions)]
 //     {
 //         assert!(
 //             tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
 //             "invalid carry: carry.len()/size_of::<Self::Scalar>()={} < rsh_tmp_bytes({}, {})",
 //             tmp_bytes.len() / size_of::<V::Scalar>(),
 //             n,
 //             size,
 //         );
 //         assert_alignement(tmp_bytes.as_ptr());
 //     }

-//     let size: usize = a.size();
-//     let steps: usize = k / log_base2k;
-
-//     a.raw_mut().rotate_right(n * steps * cols);
-//     (0..cols).for_each(|i| {
-//         (0..steps).for_each(|j| {
-//             a.zero_at(i, j);
-//         })
-//     });
-
-//     let k_rem: usize = k % log_base2k;
-
-//     if k_rem != 0 {
-//         let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
-
-//         unsafe {
-//             std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
-//         }
-
-//         let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
-//         let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
-//         let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
-
-//         (steps..size).for_each(|i| {
-//             izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
-//                 *xi += *ci << log_base2k_t;
-//                 *ci = get_base_k_carry(*xi, shift);
-//                 *xi = (*xi - *ci) >> k_rem_t;
-//             });
-//         })
-//     }
-// }
+    let steps: usize = k / log_base2k;
+
+    // Whole-limb part of the shift: rotate the dropped limbs to the front and zero them.
+    a.raw_mut().rotate_right(n * steps * cols);
+    (0..cols).for_each(|i| {
+        (0..steps).for_each(|j| {
+            a.zero_at(i, j);
+        })
+    });
+
+    let k_rem: usize = k % log_base2k;
+
+    if k_rem != 0 {
+        // `tmp_scalar_slice` expects an element count and converts to bytes
+        // itself, so pass `n` rather than `rsh_tmp_bytes::<V::Scalar>(n)`.
+        let (carry, _) = scratch.tmp_scalar_slice::<V::Scalar>(n);
+
+        let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
+        let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
+        let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
+
+        (0..cols).for_each(|i| {
+            // Reset the carry for each column (resolves the `TODO: ZERO CARRY`);
+            // this also replaces the old unconditional `write_bytes` zeroing.
+            carry.fill(V::Scalar::default());
+            (steps..size).for_each(|j| {
+                izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
+                    *xi += *ci << log_base2k_t;
+                    *ci = (*xi << shift) >> shift; // sign-extended low k_rem bits
+                    *xi = (*xi - *ci) >> k_rem_t;
+                });
+            })
+        });
+    }
+}

-// #[inline(always)]
-// fn get_base_k_carry<T: Num>(x: T, shift: T) -> T {
-//     (x << shift) >> shift
-// }
-
-// pub fn rsh_tmp_bytes<T: Num>(n: usize) -> usize {
-//     n * std::mem::size_of::<T>()
-// }
+/// Number of scratch bytes required by [`rsh`]: one carry coefficient per ring slot.
+pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
+    n * std::mem::size_of::<T>()
+}

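For reference, the rewritten routine first drops `steps` whole limbs, then propagates the residual `k_rem`-bit shift limb by limb. A self-contained single-column model over `&mut [i64]` that mirrors the logic above (toy code, assuming limb index 0 is most significant; not the crate's types):

// One column of `size` limbs in base 2^log_base2k
// (mirrors `rsh` above with n = 1 and cols = 1).
fn rsh_model(limbs: &mut [i64], k: usize, log_base2k: usize) {
    let size = limbs.len();
    let steps = (k / log_base2k).min(size);

    // Whole-limb shift: rotate the dropped limbs to the front and zero them.
    limbs.rotate_right(steps);
    limbs[..steps].fill(0);

    let k_rem = k % log_base2k;
    if k_rem != 0 {
        let shift = i64::BITS as usize - k_rem;
        let mut carry = 0i64;
        for x in limbs.iter_mut().skip(steps) {
            *x += carry << log_base2k;      // fold in the carry from the previous limb
            carry = (*x << shift) >> shift; // signed low k_rem bits
            *x = (*x - carry) >> k_rem;     // shift this limb
        }
    }
}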
-// pub trait ZnxLayout: ZnxInfos {
-//     type Scalar;
-
-//     /// Returns true if the receiver is only borrowing the data.
-//     fn borrowing(&self) -> bool {
-//         self.znx().data.len() == 0
-//     }
-
-//     /// Returns a non-mutable pointer to the underlying coefficients array.
-//     fn as_ptr(&self) -> *const Self::Scalar {
-//         self.znx().ptr as *const Self::Scalar
-//     }
-
-//     /// Returns a mutable pointer to the underlying coefficients array.
-//     fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
-//         self.znx_mut().ptr as *mut Self::Scalar
-//     }
-
-//     /// Returns a non-mutable reference to the entire underlying coefficient array.
-//     fn raw(&self) -> &[Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
-//     }
-
-//     /// Returns a mutable reference to the entire underlying coefficient array.
-//     fn raw_mut(&mut self) -> &mut [Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
-//     }
-
-//     /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
-//     fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
-//         #[cfg(debug_assertions)]
-//         {
-//             assert!(i < self.cols());
-//             assert!(j < self.size());
-//         }
-//         let offset: usize = self.n() * (j * self.cols() + i);
-//         unsafe { self.as_ptr().add(offset) }
-//     }
-
-//     /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
-//     fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
-//         #[cfg(debug_assertions)]
-//         {
-//             assert!(i < self.cols());
-//             assert!(j < self.size());
-//         }
-//         let offset: usize = self.n() * (j * self.cols() + i);
-//         unsafe { self.as_mut_ptr().add(offset) }
-//     }
-
-//     /// Returns non-mutable reference to the (i, j)-th small polynomial.
-//     fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
-//     }
-
-//     /// Returns mutable reference to the (i, j)-th small polynomial.
-//     fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
-//     }
-// }