fixed scratch API

This commit is contained in:
Pro7ech
2025-10-21 10:47:46 +02:00
parent 681ec7e349
commit fef2a2fc27
28 changed files with 112 additions and 153 deletions

View File

@@ -34,13 +34,11 @@ pub trait ScratchTakeBasic
where
Self: TakeSlice,
{
fn take_scalar_znx<M>(&mut self, module: &M, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self)
{
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(module.n(), cols));
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
(
ScalarZnx::from_data(take_slice, module.n(), cols),
ScalarZnx::from_data(take_slice, n, cols),
rem_slice,
)
}
@@ -53,13 +51,10 @@ where
(SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
}
fn take_vec_znx<M>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(module.n(), cols, size));
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self){
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
(
VecZnx::from_data(take_slice, module.n(), cols, size),
VecZnx::from_data(take_slice, n, cols, size),
rem_slice,
)
}
@@ -107,14 +102,11 @@ where
(slice, scratch)
}
fn take_vec_znx_slice<M>(&mut self, module: &M, len: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self)
where
M: ModuleN,
{
fn take_vec_znx_slice(&mut self, n: usize, len: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self){
let mut scratch: &mut Self = self;
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
for _ in 0..len {
let (znx, new_scratch) = scratch.take_vec_znx(module, cols, size);
let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size);
scratch = new_scratch;
slice.push(znx);
}
@@ -139,20 +131,18 @@ where
)
}
fn take_mat_znx<M>(
fn take_mat_znx(
&mut self,
module: &M,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self)
where
M: ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(module.n(), rows, cols_in, cols_out, size));
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
(
MatZnx::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
rem_slice,
)
}