Jean-Philippe Bossuat
2025-05-27 17:49:43 +02:00
parent dec3481a6f
commit a295085724
32 changed files with 897 additions and 1375 deletions

View File

@@ -152,81 +152,49 @@ impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
}
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
pub type MatZnxDftMut<'a, B> = MatZnxDft<&'a mut [u8], B>;
pub type MatZnxDftRef<'a, B> = MatZnxDft<&'a [u8], B>;
pub trait MatZnxDftToRef<B: Backend> {
pub trait MatZnxToRef<B: Backend> {
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
}
pub trait MatZnxDftToMut<B: Backend>: MatZnxDftToRef<B> {
impl<D, B: Backend> MatZnxToRef<B> for MatZnxDft<D, B>
where
D: AsRef<[u8]>,
B: Backend,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data.as_ref(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait MatZnxToMut<B: Backend> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
impl<D, B: Backend> MatZnxToMut<B> for MatZnxDft<D, B>
where
D: AsRef<[u8]> + AsMut<[u8]>,
B: Backend,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data.as_slice(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
_phantom: std::marker::PhantomData,
}
}
}
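
Below is a standalone sketch (toy `Buffer` type, not this crate's API) of the pattern adopted here: one blanket impl over `D: AsRef<[u8]>` (plus `AsMut<[u8]>` for the mutable view) replaces the three separate impls for `Vec<u8>`, `&mut [u8]` and `&[u8]`.

```rust
// Toy illustration only; `Buffer`, `ToRef` and `ToMut` stand in for
// MatZnxDft / MatZnxToRef / MatZnxToMut.
struct Buffer<D> {
    data: D,
    n: usize,
}

trait ToRef {
    fn to_ref(&self) -> Buffer<&[u8]>;
}

// Covers Vec<u8>, &[u8] and &mut [u8] backings in a single impl.
impl<D: AsRef<[u8]>> ToRef for Buffer<D> {
    fn to_ref(&self) -> Buffer<&[u8]> {
        Buffer { data: self.data.as_ref(), n: self.n }
    }
}

trait ToMut {
    fn to_mut(&mut self) -> Buffer<&mut [u8]>;
}

// Covers Vec<u8> and &mut [u8]; &[u8] is excluded by the AsMut bound.
impl<D: AsRef<[u8]> + AsMut<[u8]>> ToMut for Buffer<D> {
    fn to_mut(&mut self) -> Buffer<&mut [u8]> {
        Buffer { data: self.data.as_mut(), n: self.n }
    }
}

fn main() {
    let mut owned = Buffer { data: vec![0u8; 8], n: 8 };
    assert_eq!(owned.to_ref().n, 8); // Vec<u8> backing
    let mut borrowed = Buffer { data: owned.data.as_mut_slice(), n: 8 };
    let _ = borrowed.to_mut(); // the same blanket impl covers &mut [u8] backings
}
```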

View File

@@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef,
};
@@ -47,27 +47,27 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
///
/// # Arguments
///
/// * `b`: [MatZnxDft] on which the values are encoded.
/// * `res`: [MatZnxDft] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode into the [MatZnxDft].
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>;
R: MatZnxToMut<FFT64>,
A: VecZnxDftToRef<FFT64>;
/// Extracts the i-th row of a [MatZnxDft] into a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
/// * `res`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: MatZnxDftToRef<BACKEND>;
R: VecZnxDftToMut<FFT64>,
A: MatZnxToRef<FFT64>;
/// Applies the vector-matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -96,16 +96,16 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
/// * `buf`: scratch space; its size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxToRef<FFT64>;
// Same as [MatZnxDftOps::vmp_apply], except the result is added to `R` instead of overwriting `R`.
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxToRef<FFT64>;
}
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
@@ -154,10 +154,10 @@ impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<FFT64>,
R: MatZnxToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
@@ -207,9 +207,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<FFT64>,
A: MatZnxDftToRef<FFT64>,
A: MatZnxToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: MatZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
@@ -259,7 +259,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64>,
B: MatZnxToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
@@ -313,7 +313,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64>,
B: MatZnxToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
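
A minimal usage sketch under the renamed traits (assuming, as elsewhere in the crate, that `MatZnxDftOps`, `Scratch` and the conversion traits are importable from the crate root; `apply_vmp` is a hypothetical helper, not part of the API):

```rust
use crate::{FFT64, MatZnxDftOps, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef};

/// Hypothetical wrapper: res <- a x mat in the DFT domain.
/// `scratch` must be sized via MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes.
fn apply_vmp<R, A, M>(module: &Module<FFT64>, res: &mut R, a: &A, mat: &M, scratch: &mut Scratch)
where
    R: VecZnxDftToMut<FFT64>,
    A: VecZnxDftToRef<FFT64>,
    M: MatZnxToRef<FFT64>, // was MatZnxDftToRef<BACKEND> before this commit
{
    module.vmp_apply(res, a, mat, scratch);
}
```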

View File

@@ -42,7 +42,7 @@ pub trait AddNormal {
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
}
impl<T> FillUniform for VecZnx<T>
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillUniform for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
@@ -59,7 +59,7 @@ where
}
}
impl<T> FillDistF64 for VecZnx<T>
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
@@ -102,7 +102,7 @@ where
}
}
impl<T> AddDistF64 for VecZnx<T>
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
@@ -145,7 +145,7 @@ where
}
}
impl<T> FillNormal for VecZnx<T>
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
@@ -161,7 +161,7 @@ where
}
}
impl<T> AddNormal for VecZnx<T>
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
@@ -177,7 +177,7 @@ where
}
}
impl<T> FillDistF64 for VecZnxBig<T, FFT64>
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
@@ -220,7 +220,7 @@ where
}
}
impl<T> AddDistF64 for VecZnxBig<T, FFT64>
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
@@ -263,7 +263,7 @@ where
}
}
impl<T> FillNormal for VecZnxBig<T, FFT64>
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
@@ -279,7 +279,7 @@ where
}
}
impl<T> AddNormal for VecZnxBig<T, FFT64>
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
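
A hedged sketch of what the widened bounds buy (the `crate::` paths for `AddNormal`, `Source` and `VecZnx` are assumptions about where these items are exported; `add_noise` is a hypothetical helper):

```rust
use crate::{AddNormal, Source, VecZnx}; // assumed re-export paths

/// Hypothetical helper: adds Gaussian noise to column `col_i` of `a`.
/// With the `T: AsMut<[u8]> + AsRef<[u8]>` bound, this accepts owned
/// (Vec<u8>) as well as mutably borrowed (&mut [u8]) backings.
fn add_noise<T: AsMut<[u8]> + AsRef<[u8]>>(
    a: &mut VecZnx<T>,
    basek: usize,
    col_i: usize,
    k: usize,
    source: &mut Source,
    sigma: f64,
    bound: f64,
) {
    a.add_normal(basek, col_i, k, source, sigma, bound);
}
```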

View File

@@ -196,108 +196,57 @@ pub trait ScalarZnxToRef {
fn to_ref(&self) -> ScalarZnx<&[u8]>;
}
impl<D> ScalarZnxToRef for ScalarZnx<D>
where
D: AsRef<[u8]>,
{
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
}
}
}
pub trait ScalarZnxToMut {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
}
impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
impl<D> ScalarZnxToMut for ScalarZnx<D>
where
D: AsRef<[u8]> + AsMut<[u8]>,
{
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<Vec<u8>> {
impl<D> VecZnxToRef for ScalarZnx<D>
where
D: AsRef<[u8]>,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl<D> VecZnxToMut for ScalarZnx<D>
where
D: AsRef<[u8]> + AsMut<[u8]>,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: 1,
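
With the blanket `VecZnxToRef` impl above, any `ScalarZnx` backing can be viewed as a one-limb `VecZnx`. A crate-internal sketch (assuming the crate-root re-exports used elsewhere in this commit):

```rust
use crate::{ScalarZnx, VecZnx, VecZnxToRef};

/// Views any ScalarZnx backing (Vec<u8>, &[u8] or &mut [u8]) as a size-1 VecZnx.
/// The call is fully qualified because ScalarZnxToRef::to_ref also exists.
fn as_size_one_vec<D: AsRef<[u8]>>(a: &ScalarZnx<D>) -> VecZnx<&[u8]> {
    VecZnxToRef::to_ref(a)
}
```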

View File

@@ -113,14 +113,33 @@ pub trait ScalarZnxDftToRef<B: Backend> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
}
impl<D, B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<D, B>
where
D: AsRef<[u8]>,
B: Backend,
{
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
pub trait ScalarZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
impl<D, B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<D, B>
where
D: AsMut<[u8]> + AsRef<[u8]>,
B: Backend,
{
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
@@ -128,106 +147,34 @@ impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data.as_slice(),
impl<D, B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<D, B>
where
D: AsRef<[u8]>,
B: Backend,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
size: 1,
_phantom: std::marker::PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
impl<D, B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<D, B>
where
D: AsRef<[u8]> + AsMut<[u8]>,
B: Backend,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
_phantom: std::marker::PhantomData,
}
}
}
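
The same idea applies in the DFT domain: a prepared scalar can now feed any API that is generic over `VecZnxDftToRef`. A crate-internal sketch (item paths assumed from the crate root):

```rust
use crate::{Backend, ScalarZnxDft, VecZnxDft, VecZnxDftToRef};

/// Views a ScalarZnxDft as a size-1 VecZnxDft; the call is fully qualified
/// because ScalarZnxDftToRef::to_ref also exists for the same type.
fn scalar_as_vec_dft<D, B>(s: &ScalarZnxDft<D, B>) -> VecZnxDft<&[u8], B>
where
    D: AsRef<[u8]>,
    B: Backend,
{
    VecZnxDftToRef::to_ref(s)
}
```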

View File

@@ -1,103 +1,105 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxDftToRef<FFT64>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxDftToRef<BACKEND>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}
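
A usage sketch for `svp_apply` under the generic conversion traits (assuming `ScalarZnxDftOps` is reachable from the crate root; `svp_column` is a hypothetical helper):

```rust
use crate::{FFT64, Module, ScalarZnxDftOps, ScalarZnxDftToRef, VecZnxDftToMut, VecZnxDftToRef};

/// Hypothetical helper: res[col] <- a[col] * b[col] in the DFT domain.
/// `res`, `a` and `b` may use any backing (owned or borrowed) thanks to the
/// blanket conversion impls.
fn svp_column<R, A, V>(module: &Module<FFT64>, res: &mut R, a: &A, b: &V, col: usize)
where
    R: VecZnxDftToMut<FFT64>,
    A: ScalarZnxDftToRef<FFT64>,
    V: VecZnxDftToRef<FFT64>,
{
    module.svp_apply(res, col, a, col, b, col);
}
```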

View File

@@ -237,14 +237,15 @@ fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_co
}
}
impl<D> VecZnx<D>
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D>
where
VecZnx<D>: VecZnxToMut + ZnxInfos,
{
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
pub fn extract_column<R>(&mut self, self_col: usize, a: &R, a_col: usize)
pub fn extract_column<R>(&mut self, self_col: usize, a: &VecZnx<R>, a_col: usize)
where
R: VecZnxToRef + ZnxInfos,
R: AsRef<[u8]>,
VecZnx<R>: VecZnxToRef + ZnxInfos,
{
#[cfg(debug_assertions)]
{
@@ -313,72 +314,41 @@ pub trait VecZnxToRef {
fn to_ref(&self) -> VecZnx<&[u8]>;
}
pub trait VecZnxToMut: VecZnxToRef {
impl<D> VecZnxToRef for VecZnx<D>
where
D: AsRef<[u8]>,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
pub trait VecZnxToMut {
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
}
impl VecZnxToMut for VecZnx<Vec<u8>> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToMut for VecZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl<DataSelf> VecZnx<DataSelf>
impl<D> VecZnxToMut for VecZnx<D>
where
VecZnx<DataSelf>: VecZnxToRef,
D: AsRef<[u8]> + AsMut<[u8]>,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl<DataSelf: AsRef<[u8]>> VecZnx<DataSelf> {
pub fn clone(&self) -> VecZnx<Vec<u8>> {
let self_ref: VecZnx<&[u8]> = self.to_ref();
VecZnx {
data: self_ref.data.to_vec(),
n: self_ref.n,

View File

@@ -94,7 +94,7 @@ impl<D, B: Backend> VecZnxBig<D, B> {
}
}
impl<D> VecZnxBig<D, FFT64>
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxBig<D, FFT64>
where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
@@ -110,9 +110,9 @@ where
}
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
pub fn extract_column<C>(&mut self, self_col: usize, a: &C, a_col: usize)
where
VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
C: VecZnxBigToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
@@ -144,66 +144,38 @@ pub trait VecZnxBigToRef<B: Backend> {
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
}
impl<D, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B>
where
D: AsRef<[u8]>,
B: Backend,
{
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait VecZnxBigToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
impl<D, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B>
where
D: AsRef<[u8]> + AsMut<[u8]>,
B: Backend,
{
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
_phantom: std::marker::PhantomData,
}
}
}
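
The `VecZnxBig` variant goes the other way: `extract_column` now accepts any source implementing `VecZnxBigToRef<FFT64>`, not only a `VecZnxBig`. A sketch with bounds copied from the impl above (`copy_big_column` is hypothetical):

```rust
use crate::znx_base::ZnxInfos;
use crate::{FFT64, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef};

/// Hypothetical helper: copies column `a_col` of `src` into column `dst_col` of `dst`.
fn copy_big_column<D, C>(dst: &mut VecZnxBig<D, FFT64>, dst_col: usize, src: &C, a_col: usize)
where
    D: AsMut<[u8]> + AsRef<[u8]>,
    C: VecZnxBigToRef<FFT64> + ZnxInfos,
    VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
    dst.extract_column(dst_col, src, a_col);
}
```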

View File

@@ -128,7 +128,7 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxBigToRef<BACKEND>;
A: VecZnxBigToRef<FFT64>;
/// Applies the automorphism X^i -> X^ik to `a` and stores the result in `res`.
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -501,7 +501,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
@@ -513,10 +513,10 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(res_col, 0),
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(res_col, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
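
A small sketch of the renamed `a_col` argument (assuming `VecZnxBigOps` is reachable from the crate root; `negate_columns` is a hypothetical helper):

```rust
use crate::{FFT64, Module, VecZnxBigOps, VecZnxBigToMut};

/// Hypothetical helper: negates columns 0..cols of `a` in place.
fn negate_columns<A>(module: &Module<FFT64>, a: &mut A, cols: usize)
where
    A: VecZnxBigToMut<FFT64>,
{
    for a_col in 0..cols {
        module.vec_znx_big_negate_inplace(a, a_col);
    }
}
```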

View File

@@ -91,14 +91,14 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
}
}
impl<D> VecZnxDft<D, FFT64>
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
pub fn extract_column<C: AsRef<[u8]>>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
@@ -142,66 +142,38 @@ pub trait VecZnxDftToRef<B: Backend> {
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
}
pub trait VecZnxDftToMut<B: Backend>: VecZnxDftToRef<B> {
impl<D, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B>
where
D: AsRef<[u8]>,
B: Backend,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait VecZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
impl<D, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B>
where
D: AsRef<[u8]> + AsMut<[u8]>,
B: Backend,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
_phantom: std::marker::PhantomData,
}
}
}

View File

@@ -59,7 +59,7 @@ pub trait VecZnxOps {
A: VecZnxToRef;
/// Adds the selected column of `a` to the selected column and limb of `res`.
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
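
A final sketch of the renamed `a_col` parameter in `vec_znx_add_scalar_inplace` (assuming the trait and conversion traits are importable from the crate root; `add_scalar_column` is a hypothetical helper):

```rust
use crate::{ScalarZnxToRef, VecZnxOps, VecZnxToMut};

/// Hypothetical helper: adds column `a_col` of the scalar `a` onto
/// column `res_col`, limb `res_limb` of `res`.
fn add_scalar_column<M, R, A>(module: &M, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
    M: VecZnxOps,
    R: VecZnxToMut,
    A: ScalarZnxToRef,
{
    module.vec_znx_add_scalar_inplace(res, res_col, res_limb, a, a_col);
}
```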