mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
This commit is contained in:
@@ -4,3 +4,7 @@ use crate::layouts::Backend;
|
||||
pub trait ModuleNew<B: Backend> {
|
||||
fn new(n: u64) -> Self;
|
||||
}
|
||||
|
||||
pub trait ModuleN {
|
||||
fn n(&self) -> usize;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
api::{SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
};
|
||||
|
||||
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
|
||||
@@ -28,11 +28,14 @@ pub trait TakeSlice {
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
|
||||
}
|
||||
|
||||
pub trait ScratchTakeBasic<B: Backend>
|
||||
pub trait ScratchTakeBasic
|
||||
where
|
||||
Self: TakeSlice,
|
||||
{
|
||||
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
fn take_scalar_znx<M>(&mut self, module: &M, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self)
|
||||
where
|
||||
M: ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(module.n(), cols));
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, module.n(), cols),
|
||||
@@ -40,15 +43,18 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
|
||||
fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
|
||||
where
|
||||
Module<B>: SvpPPolBytesOf,
|
||||
M: SvpPPolBytesOf + ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
|
||||
(SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
|
||||
}
|
||||
|
||||
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
fn take_vec_znx<M>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self)
|
||||
where
|
||||
M: ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(module.n(), cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, module.n(), cols, size),
|
||||
@@ -56,9 +62,9 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
|
||||
fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
|
||||
where
|
||||
Module<B>: VecZnxBigBytesOf,
|
||||
M: VecZnxBigBytesOf + ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
|
||||
(
|
||||
@@ -67,9 +73,9 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
|
||||
fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
|
||||
where
|
||||
Module<B>: VecZnxDftBytesOf,
|
||||
M: VecZnxDftBytesOf + ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
|
||||
|
||||
@@ -79,15 +85,15 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn take_vec_znx_dft_slice(
|
||||
fn take_vec_znx_dft_slice<M, B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
module: &M,
|
||||
len: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
|
||||
where
|
||||
Module<B>: VecZnxDftBytesOf,
|
||||
M: VecZnxDftBytesOf + ModuleN,
|
||||
{
|
||||
let mut scratch: &mut Self = self;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
|
||||
@@ -99,13 +105,10 @@ where
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
fn take_vec_znx_slice(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
len: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
|
||||
fn take_vec_znx_slice<M>(&mut self, module: &M, len: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self)
|
||||
where
|
||||
M: ModuleN,
|
||||
{
|
||||
let mut scratch: &mut Self = self;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
@@ -116,16 +119,16 @@ where
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
fn take_vmp_pmat(
|
||||
fn take_vmp_pmat<M, B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
module: &M,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Self)
|
||||
where
|
||||
Module<B>: VmpPMatBytesOf,
|
||||
M: VmpPMatBytesOf + ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
|
||||
(
|
||||
@@ -134,14 +137,17 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn take_mat_znx(
|
||||
fn take_mat_znx<M>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
module: &M,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self) {
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self)
|
||||
where
|
||||
M: ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(module.n(), rows, cols_in, cols_out, size));
|
||||
(
|
||||
MatZnx::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
|
||||
|
||||
Reference in New Issue
Block a user