Add Zn type

This commit is contained in:
Pro7ech
2025-08-21 12:16:53 +02:00
parent ccd94e36cc
commit bf513dc555
129 changed files with 1400 additions and 686 deletions

View File

@@ -41,18 +41,10 @@ unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
}
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
fn vmp_prepare_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize {
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe {
vmp::vmp_prepare_tmp_bytes(
module.ptr(),
n as u64,
(rows * cols_in) as u64,
(cols_out * size) as u64,
) as usize
@@ -102,8 +94,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
);
}
let (tmp_bytes, _) =
scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));
let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));
unsafe {
vmp::vmp_prepare_contiguous(
@@ -121,7 +112,6 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -132,7 +122,6 @@ unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
@@ -174,7 +163,6 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),
@@ -201,7 +189,6 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -212,7 +199,6 @@ unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
@@ -254,7 +240,6 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),