This commit is contained in:
Pro7ech
2025-10-15 10:48:14 +02:00
parent a5df85170d
commit 008b800c01
74 changed files with 890 additions and 871 deletions

View File

@@ -4,3 +4,7 @@ use crate::layouts::Backend;
pub trait ModuleNew<B: Backend> {
fn new(n: u64) -> Self;
}
pub trait ModuleN {
fn n(&self) -> usize;
}

View File

@@ -1,6 +1,6 @@
use crate::{
api::{SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
};
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
@@ -28,11 +28,14 @@ pub trait TakeSlice {
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}
pub trait ScratchTakeBasic<B: Backend>
pub trait ScratchTakeBasic
where
Self: TakeSlice,
{
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
fn take_scalar_znx<M>(&mut self, module: &M, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(module.n(), cols));
(
ScalarZnx::from_data(take_slice, module.n(), cols),
@@ -40,15 +43,18 @@ where
)
}
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
where
Module<B>: SvpPPolBytesOf,
M: SvpPPolBytesOf + ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
(SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
}
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
fn take_vec_znx<M>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(module.n(), cols, size));
(
VecZnx::from_data(take_slice, module.n(), cols, size),
@@ -56,9 +62,9 @@ where
)
}
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
where
Module<B>: VecZnxBigBytesOf,
M: VecZnxBigBytesOf + ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
(
@@ -67,9 +73,9 @@ where
)
}
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
where
Module<B>: VecZnxDftBytesOf,
M: VecZnxDftBytesOf + ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
@@ -79,15 +85,15 @@ where
)
}
fn take_vec_znx_dft_slice(
fn take_vec_znx_dft_slice<M, B: Backend>(
&mut self,
module: &Module<B>,
module: &M,
len: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
where
Module<B>: VecZnxDftBytesOf,
M: VecZnxDftBytesOf + ModuleN,
{
let mut scratch: &mut Self = self;
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
@@ -99,13 +105,10 @@ where
(slice, scratch)
}
fn take_vec_znx_slice(
&mut self,
module: &Module<B>,
len: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
fn take_vec_znx_slice<M>(&mut self, module: &M, len: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self)
where
M: ModuleN,
{
let mut scratch: &mut Self = self;
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
for _ in 0..len {
@@ -116,16 +119,16 @@ where
(slice, scratch)
}
fn take_vmp_pmat(
fn take_vmp_pmat<M, B: Backend>(
&mut self,
module: &Module<B>,
module: &M,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Self)
where
Module<B>: VmpPMatBytesOf,
M: VmpPMatBytesOf + ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
(
@@ -134,14 +137,17 @@ where
)
}
fn take_mat_znx(
fn take_mat_znx<M>(
&mut self,
module: &Module<B>,
module: &M,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self) {
) -> (MatZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(module.n(), rows, cols_in, cols_out, size));
(
MatZnx::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),