Backend refactor (#120)

* remove spqlios; split cpu_ref and cpu_avx into separate crates

* remove spqlios submodule

* update crate naming and add AVX tests
This commit is contained in:
Jean-Philippe Bossuat
2025-11-19 15:34:31 +01:00
committed by GitHub
parent 84598e42fe
commit 9e007c988f
182 changed files with 1053 additions and 4483 deletions

View File

@@ -20,6 +20,7 @@ byteorder = {workspace = true}
once_cell = {workspace = true}
rand_chacha = "0.9.0"
bytemuck = {workspace = true}
paste = "1.0.15"
[build-dependencies]

View File

@@ -1,8 +1,8 @@
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply,
VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc,
VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{DataViewMut, DigestU64, FillUniform, MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
@@ -129,7 +129,9 @@ where
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
+ VecZnxDftApply<BR>
+ VmpPrepareTmpBytes
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftTmpBytes
@@ -139,7 +141,9 @@ where
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
+ VecZnxDftApply<BT>
+ VmpPrepareTmpBytes
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -151,10 +155,16 @@ where
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
module_ref
.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)
.max(module_ref.vmp_prepare_tmp_bytes(max_size, max_cols, max_cols, max_size))
.max(module_ref.vec_znx_big_normalize_tmp_bytes()),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
module_test
.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)
.max(module_test.vmp_prepare_tmp_bytes(max_size, max_cols, max_cols, max_size))
.max(module_test.vec_znx_big_normalize_tmp_bytes()),
);
for cols_in in 1..max_cols + 1 {
@@ -258,7 +268,9 @@ where
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
+ VecZnxDftApply<BR>
+ VmpPrepareTmpBytes
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftAddTmpBytes
@@ -268,7 +280,9 @@ where
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
+ VecZnxDftApply<BT>
+ VmpPrepareTmpBytes
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -280,10 +294,16 @@ where
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
module_ref
.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)
.max(module_ref.vmp_prepare_tmp_bytes(max_size, max_cols, max_cols, max_size))
.max(module_ref.vec_znx_big_normalize_tmp_bytes()),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
module_test
.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)
.max(module_test.vmp_prepare_tmp_bytes(max_size, max_cols, max_cols, max_size))
.max(module_test.vec_znx_big_normalize_tmp_bytes()),
);
for cols_in in 1..max_cols + 1 {