mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
renamed vmp API closer to spqlios
This commit is contained in:
@@ -24,8 +24,8 @@ pub trait VmpPrepare<B: Backend> {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait VmpApplyTmpBytes {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
pub trait VmpApplyDftToDftTmpBytes {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -36,7 +36,7 @@ pub trait VmpApplyTmpBytes {
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApply<B: Backend> {
|
||||
pub trait VmpApplyDftToDft<B: Backend> {
|
||||
/// Applies the vector matrix product [crate::layouts::VecZnxDft] x [crate::layouts::VmpPMat].
|
||||
///
|
||||
/// A vector matrix product numerically equivalent to a sum of [crate::api::SvpApply],
|
||||
@@ -61,7 +61,7 @@ pub trait VmpApply<B: Backend> {
|
||||
/// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
|
||||
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
@@ -69,8 +69,8 @@ pub trait VmpApply<B: Backend> {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait VmpApplyAddTmpBytes {
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
pub trait VmpApplyDftToDftAddTmpBytes {
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -81,8 +81,8 @@ pub trait VmpApplyAddTmpBytes {
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApplyAdd<B: Backend> {
|
||||
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
pub trait VmpApplyDftToDftAdd<B: Backend> {
|
||||
fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use crate::{
|
||||
api::{
|
||||
VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
|
||||
VmpPrepare, VmpPrepareTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
|
||||
VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
|
||||
},
|
||||
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
|
||||
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -59,11 +59,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyTmpBytes for Module<B>
|
||||
impl<B> VmpApplyDftToDftTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyTmpBytesImpl<B>,
|
||||
B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_apply_tmp_bytes(
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -72,31 +72,31 @@ where
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_tmp_bytes_impl(
|
||||
B::vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApply<B> for Module<B>
|
||||
impl<B> VmpApplyDftToDft<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyImpl<B>,
|
||||
B: Backend + VmpApplyDftToDftImpl<B>,
|
||||
{
|
||||
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>,
|
||||
{
|
||||
B::vmp_apply_impl(self, res, a, b, scratch);
|
||||
B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyAddTmpBytes for Module<B>
|
||||
impl<B> VmpApplyDftToDftAddTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyAddTmpBytesImpl<B>,
|
||||
B: Backend + VmpApplyDftToDftAddTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -105,22 +105,22 @@ where
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_add_tmp_bytes_impl(
|
||||
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyAdd<B> for Module<B>
|
||||
impl<B> VmpApplyDftToDftAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyAddImpl<B>,
|
||||
B: Backend + VmpApplyDftToDftAddImpl<B>,
|
||||
{
|
||||
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>,
|
||||
{
|
||||
B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
|
||||
B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,8 +57,8 @@ pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
pub unsafe trait VmpApplyDftToDftTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -73,8 +73,8 @@ pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyImpl<B: Backend> {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
pub unsafe trait VmpApplyDftToDftImpl<B: Backend> {
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
@@ -86,8 +86,8 @@ pub unsafe trait VmpApplyImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
pub unsafe trait VmpApplyDftToDftAddTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
@@ -102,10 +102,16 @@ pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyAddImpl<B: Backend> {
|
||||
pub unsafe trait VmpApplyDftToDftAddImpl<B: Backend> {
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
a: &A,
|
||||
b: &C,
|
||||
scale: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use crate::{
|
||||
api::{
|
||||
DFT, IDFTTmpA, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VmpApply, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VmpApplyDftToDft, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
|
||||
},
|
||||
layouts::{MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos, ZnxViewMut},
|
||||
oep::{
|
||||
DFTImpl, IDFTTmpAImpl, ModuleNewImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, VecZnxBigAllocImpl,
|
||||
VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxDftAllocImpl, VmpApplyImpl, VmpApplyTmpBytesImpl,
|
||||
VmpPMatAllocImpl, VmpPMatPrepareImpl,
|
||||
VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxDftAllocImpl, VmpApplyDftToDftImpl,
|
||||
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocImpl, VmpPMatPrepareImpl,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -17,14 +17,14 @@ pub fn test_vmp_apply<B>()
|
||||
where
|
||||
B: Backend
|
||||
+ ModuleNewImpl<B>
|
||||
+ VmpApplyTmpBytesImpl<B>
|
||||
+ VmpApplyDftToDftTmpBytesImpl<B>
|
||||
+ VecZnxBigNormalizeTmpBytesImpl<B>
|
||||
+ VmpPMatAllocImpl<B>
|
||||
+ VecZnxDftAllocImpl<B>
|
||||
+ VecZnxBigAllocImpl<B>
|
||||
+ VmpPMatPrepareImpl<B>
|
||||
+ DFTImpl<B>
|
||||
+ VmpApplyImpl<B>
|
||||
+ VmpApplyDftToDftImpl<B>
|
||||
+ IDFTTmpAImpl<B>
|
||||
+ ScratchOwnedAllocImpl<B>
|
||||
+ ScratchOwnedBorrowImpl<B>
|
||||
@@ -49,7 +49,7 @@ where
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
@@ -89,7 +89,7 @@ where
|
||||
module.dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &vmp, scratch.borrow());
|
||||
module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user