renamed vmp API closer to spqlios

This commit is contained in:
Pro7ech
2025-08-25 11:58:51 +02:00
parent 1551f7a6f0
commit a1b865709d
38 changed files with 431 additions and 393 deletions

View File

@@ -1,12 +1,12 @@
use crate::{
api::{
VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
VmpPrepare, VmpPrepareTmpBytes,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
oep::{
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
},
};
@@ -59,11 +59,11 @@ where
}
}
impl<B> VmpApplyTmpBytes for Module<B>
impl<B> VmpApplyDftToDftTmpBytes for Module<B>
where
B: Backend + VmpApplyTmpBytesImpl<B>,
B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
{
fn vmp_apply_tmp_bytes(
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
@@ -72,31 +72,31 @@ where
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_tmp_bytes_impl(
B::vmp_apply_dft_to_dft_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApply<B> for Module<B>
impl<B> VmpApplyDftToDft<B> for Module<B>
where
B: Backend + VmpApplyImpl<B>,
B: Backend + VmpApplyDftToDftImpl<B>,
{
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>,
{
B::vmp_apply_impl(self, res, a, b, scratch);
B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
}
}
impl<B> VmpApplyAddTmpBytes for Module<B>
impl<B> VmpApplyDftToDftAddTmpBytes for Module<B>
where
B: Backend + VmpApplyAddTmpBytesImpl<B>,
B: Backend + VmpApplyDftToDftAddTmpBytesImpl<B>,
{
fn vmp_apply_add_tmp_bytes(
fn vmp_apply_dft_to_dft_add_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
@@ -105,22 +105,22 @@ where
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_add_tmp_bytes_impl(
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApplyAdd<B> for Module<B>
impl<B> VmpApplyDftToDftAdd<B> for Module<B>
where
B: Backend + VmpApplyAddImpl<B>,
B: Backend + VmpApplyDftToDftAddImpl<B>,
{
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>,
{
B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
}
}