From 06e4e58b2dec9c0db3e953d9d026f8777126ef98 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sun, 26 Jan 2025 12:26:44 +0100 Subject: [PATCH] spqlios basic wrapper --- Cargo.lock | 113 +- Cargo.toml | 2 +- spqlios/Cargo.toml | 11 + spqlios/build.rs | 52 + spqlios/examples/fft.rs | 57 + spqlios/lib/.clang-format | 14 + spqlios/lib/.gitignore | 4 + spqlios/lib/CMakeLists.txt | 69 + spqlios/lib/CONTRIBUTING.md | 77 + spqlios/lib/Changelog.md | 18 + spqlios/lib/LICENSE | 201 ++ spqlios/lib/README.md | 65 + spqlios/lib/docs/api-full.svg | 416 +++++ spqlios/lib/docs/logo-inpher1.png | Bin 0 -> 24390 bytes spqlios/lib/docs/logo-inpher2.png | Bin 0 -> 24239 bytes spqlios/lib/docs/logo-sandboxaq-black.svg | 139 ++ spqlios/lib/docs/logo-sandboxaq-white.svg | 133 ++ spqlios/lib/manifest.yaml | 2 + spqlios/lib/scripts/auto-release.sh | 27 + spqlios/lib/scripts/ci-pkg | 102 ++ spqlios/lib/scripts/prepare-release | 181 ++ spqlios/lib/spqlios/CMakeLists.txt | 223 +++ spqlios/lib/spqlios/arithmetic/module_api.c | 164 ++ .../arithmetic/scalar_vector_product.c | 63 + spqlios/lib/spqlios/arithmetic/vec_rnx_api.c | 318 ++++ .../arithmetic/vec_rnx_approxdecomp_avx.c | 59 + .../arithmetic/vec_rnx_approxdecomp_ref.c | 75 + .../spqlios/arithmetic/vec_rnx_arithmetic.c | 223 +++ .../spqlios/arithmetic/vec_rnx_arithmetic.h | 340 ++++ .../arithmetic/vec_rnx_arithmetic_avx.c | 189 ++ .../arithmetic/vec_rnx_arithmetic_plugin.h | 88 + .../arithmetic/vec_rnx_arithmetic_private.h | 284 +++ .../arithmetic/vec_rnx_conversions_ref.c | 91 + .../lib/spqlios/arithmetic/vec_rnx_svp_ref.c | 47 + .../lib/spqlios/arithmetic/vec_rnx_vmp_avx.c | 196 ++ .../lib/spqlios/arithmetic/vec_rnx_vmp_ref.c | 251 +++ spqlios/lib/spqlios/arithmetic/vec_znx.c | 333 ++++ .../spqlios/arithmetic/vec_znx_arithmetic.h | 357 ++++ .../arithmetic/vec_znx_arithmetic_private.h | 481 +++++ spqlios/lib/spqlios/arithmetic/vec_znx_avx.c | 103 ++ spqlios/lib/spqlios/arithmetic/vec_znx_big.c | 270 +++ spqlios/lib/spqlios/arithmetic/vec_znx_dft.c | 162 ++ .../lib/spqlios/arithmetic/vec_znx_dft_avx2.c | 1 + .../arithmetic/vector_matrix_product.c | 240 +++ .../arithmetic/vector_matrix_product_avx.c | 137 ++ spqlios/lib/spqlios/arithmetic/zn_api.c | 169 ++ .../spqlios/arithmetic/zn_approxdecomp_ref.c | 81 + .../lib/spqlios/arithmetic/zn_arithmetic.h | 135 ++ .../spqlios/arithmetic/zn_arithmetic_plugin.h | 39 + .../arithmetic/zn_arithmetic_private.h | 150 ++ .../spqlios/arithmetic/zn_conversions_ref.c | 108 ++ .../lib/spqlios/arithmetic/zn_vmp_int16_avx.c | 4 + .../lib/spqlios/arithmetic/zn_vmp_int16_ref.c | 4 + .../lib/spqlios/arithmetic/zn_vmp_int32_avx.c | 223 +++ .../lib/spqlios/arithmetic/zn_vmp_int32_ref.c | 88 + .../lib/spqlios/arithmetic/zn_vmp_int8_avx.c | 4 + .../lib/spqlios/arithmetic/zn_vmp_int8_ref.c | 4 + spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c | 138 ++ spqlios/lib/spqlios/arithmetic/znx_small.c | 38 + .../lib/spqlios/coeffs/coeffs_arithmetic.c | 496 +++++ .../lib/spqlios/coeffs/coeffs_arithmetic.h | 78 + .../spqlios/coeffs/coeffs_arithmetic_avx.c | 124 ++ spqlios/lib/spqlios/commons.c | 165 ++ spqlios/lib/spqlios/commons.h | 77 + spqlios/lib/spqlios/commons_private.c | 55 + spqlios/lib/spqlios/commons_private.h | 72 + spqlios/lib/spqlios/cplx/README.md | 22 + spqlios/lib/spqlios/cplx/cplx_common.c | 80 + spqlios/lib/spqlios/cplx/cplx_conversions.c | 158 ++ .../spqlios/cplx/cplx_conversions_avx2_fma.c | 104 ++ spqlios/lib/spqlios/cplx/cplx_execute.c | 18 + .../lib/spqlios/cplx/cplx_fallbacks_aarch64.c | 41 + 
spqlios/lib/spqlios/cplx/cplx_fft.h | 221 +++ spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s | 156 ++ .../spqlios/cplx/cplx_fft16_avx_fma_win32.s | 190 ++ spqlios/lib/spqlios/cplx/cplx_fft_asserts.c | 8 + spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c | 266 +++ spqlios/lib/spqlios/cplx/cplx_fft_avx512.c | 453 +++++ spqlios/lib/spqlios/cplx/cplx_fft_internal.h | 123 ++ spqlios/lib/spqlios/cplx/cplx_fft_private.h | 109 ++ spqlios/lib/spqlios/cplx/cplx_fft_ref.c | 367 ++++ spqlios/lib/spqlios/cplx/cplx_fft_sse.c | 310 ++++ .../lib/spqlios/cplx/cplx_fftvec_avx2_fma.c | 389 ++++ spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c | 85 + .../lib/spqlios/cplx/cplx_ifft16_avx_fma.s | 157 ++ .../spqlios/cplx/cplx_ifft16_avx_fma_win32.s | 192 ++ spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c | 267 +++ spqlios/lib/spqlios/cplx/cplx_ifft_ref.c | 315 ++++ spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c | 0 spqlios/lib/spqlios/ext/neon_accel/macrof.h | 138 ++ spqlios/lib/spqlios/ext/neon_accel/macrofx4.h | 428 +++++ spqlios/lib/spqlios/q120/q120_arithmetic.h | 115 ++ .../lib/spqlios/q120/q120_arithmetic_avx2.c | 567 ++++++ .../spqlios/q120/q120_arithmetic_private.h | 37 + .../lib/spqlios/q120/q120_arithmetic_ref.c | 506 +++++ .../lib/spqlios/q120/q120_arithmetic_simple.c | 111 ++ spqlios/lib/spqlios/q120/q120_common.h | 94 + .../lib/spqlios/q120/q120_fallbacks_aarch64.c | 5 + spqlios/lib/spqlios/q120/q120_ntt.c | 340 ++++ spqlios/lib/spqlios/q120/q120_ntt.h | 25 + spqlios/lib/spqlios/q120/q120_ntt_avx2.c | 479 +++++ spqlios/lib/spqlios/q120/q120_ntt_private.h | 39 + spqlios/lib/spqlios/reim/reim_conversions.c | 212 +++ .../lib/spqlios/reim/reim_conversions_avx.c | 106 ++ spqlios/lib/spqlios/reim/reim_execute.c | 22 + .../lib/spqlios/reim/reim_fallbacks_aarch64.c | 15 + spqlios/lib/spqlios/reim/reim_fft.h | 207 +++ spqlios/lib/spqlios/reim/reim_fft16_avx_fma.s | 167 ++ .../spqlios/reim/reim_fft16_avx_fma_win32.s | 203 ++ spqlios/lib/spqlios/reim/reim_fft4_avx_fma.c | 66 + spqlios/lib/spqlios/reim/reim_fft8_avx_fma.c | 89 + spqlios/lib/spqlios/reim/reim_fft_avx2.c | 162 ++ .../lib/spqlios/reim/reim_fft_core_template.h | 162 ++ spqlios/lib/spqlios/reim/reim_fft_ifft.c | 37 + spqlios/lib/spqlios/reim/reim_fft_internal.h | 143 ++ spqlios/lib/spqlios/reim/reim_fft_neon.c | 1627 +++++++++++++++++ spqlios/lib/spqlios/reim/reim_fft_private.h | 101 + spqlios/lib/spqlios/reim/reim_fft_ref.c | 437 +++++ .../lib/spqlios/reim/reim_fftvec_addmul_fma.c | 75 + .../lib/spqlios/reim/reim_fftvec_addmul_ref.c | 54 + .../lib/spqlios/reim/reim_ifft16_avx_fma.s | 192 ++ .../spqlios/reim/reim_ifft16_avx_fma_win32.s | 228 +++ spqlios/lib/spqlios/reim/reim_ifft4_avx_fma.c | 62 + spqlios/lib/spqlios/reim/reim_ifft8_avx_fma.c | 86 + spqlios/lib/spqlios/reim/reim_ifft_avx2.c | 167 ++ spqlios/lib/spqlios/reim/reim_ifft_ref.c | 409 +++++ spqlios/lib/spqlios/reim/reim_to_tnx_avx.c | 32 + spqlios/lib/spqlios/reim/reim_to_tnx_ref.c | 72 + spqlios/lib/spqlios/reim4/reim4_arithmetic.h | 149 ++ .../lib/spqlios/reim4/reim4_arithmetic_avx2.c | 130 ++ .../lib/spqlios/reim4/reim4_arithmetic_ref.c | 214 +++ spqlios/lib/spqlios/reim4/reim4_execute.c | 19 + .../spqlios/reim4/reim4_fallbacks_aarch64.c | 11 + .../spqlios/reim4/reim4_fftvec_addmul_fma.c | 54 + .../spqlios/reim4/reim4_fftvec_addmul_ref.c | 97 + .../lib/spqlios/reim4/reim4_fftvec_conv_fma.c | 37 + .../lib/spqlios/reim4/reim4_fftvec_conv_ref.c | 116 ++ .../lib/spqlios/reim4/reim4_fftvec_internal.h | 20 + .../lib/spqlios/reim4/reim4_fftvec_private.h | 33 + .../lib/spqlios/reim4/reim4_fftvec_public.h | 59 
+ spqlios/lib/test/CMakeLists.txt | 142 ++ .../test/spqlios_coeffs_arithmetic_test.cpp | 488 +++++ .../test/spqlios_cplx_conversions_test.cpp | 86 + spqlios/lib/test/spqlios_cplx_fft_bench.cpp | 112 ++ spqlios/lib/test/spqlios_cplx_test.cpp | 496 +++++ .../test/spqlios_q120_arithmetic_bench.cpp | 136 ++ .../lib/test/spqlios_q120_arithmetic_test.cpp | 437 +++++ spqlios/lib/test/spqlios_q120_ntt_bench.cpp | 44 + spqlios/lib/test/spqlios_q120_ntt_test.cpp | 174 ++ .../test/spqlios_reim4_arithmetic_bench.cpp | 52 + .../test/spqlios_reim4_arithmetic_test.cpp | 253 +++ .../test/spqlios_reim_conversions_test.cpp | 115 ++ spqlios/lib/test/spqlios_reim_test.cpp | 477 +++++ spqlios/lib/test/spqlios_svp_product_test.cpp | 28 + spqlios/lib/test/spqlios_svp_test.cpp | 47 + spqlios/lib/test/spqlios_test.cpp | 493 +++++ ...qlios_vec_rnx_approxdecomp_tnxdbl_test.cpp | 42 + .../test/spqlios_vec_rnx_conversions_test.cpp | 134 ++ .../lib/test/spqlios_vec_rnx_ppol_test.cpp | 73 + spqlios/lib/test/spqlios_vec_rnx_test.cpp | 417 +++++ spqlios/lib/test/spqlios_vec_rnx_vmp_test.cpp | 291 +++ spqlios/lib/test/spqlios_vec_znx_big_test.cpp | 265 +++ spqlios/lib/test/spqlios_vec_znx_dft_test.cpp | 193 ++ spqlios/lib/test/spqlios_vec_znx_test.cpp | 546 ++++++ spqlios/lib/test/spqlios_vmp_product_test.cpp | 121 ++ .../lib/test/spqlios_zn_approxdecomp_test.cpp | 46 + .../lib/test/spqlios_zn_conversions_test.cpp | 104 ++ spqlios/lib/test/spqlios_zn_vmp_test.cpp | 67 + spqlios/lib/test/spqlios_znx_small_test.cpp | 26 + spqlios/lib/test/testlib/fft64_dft.cpp | 168 ++ spqlios/lib/test/testlib/fft64_dft.h | 43 + spqlios/lib/test/testlib/fft64_layouts.cpp | 238 +++ spqlios/lib/test/testlib/fft64_layouts.h | 109 ++ spqlios/lib/test/testlib/mod_q120.cpp | 229 +++ spqlios/lib/test/testlib/mod_q120.h | 49 + .../test/testlib/negacyclic_polynomial.cpp | 18 + .../lib/test/testlib/negacyclic_polynomial.h | 69 + .../test/testlib/negacyclic_polynomial_impl.h | 247 +++ spqlios/lib/test/testlib/ntt120_dft.cpp | 122 ++ spqlios/lib/test/testlib/ntt120_dft.h | 31 + spqlios/lib/test/testlib/ntt120_layouts.cpp | 66 + spqlios/lib/test/testlib/ntt120_layouts.h | 103 ++ .../lib/test/testlib/polynomial_vector.cpp | 69 + spqlios/lib/test/testlib/polynomial_vector.h | 42 + spqlios/lib/test/testlib/random.cpp | 55 + spqlios/lib/test/testlib/reim4_elem.cpp | 145 ++ spqlios/lib/test/testlib/reim4_elem.h | 95 + spqlios/lib/test/testlib/sha3.c | 168 ++ spqlios/lib/test/testlib/sha3.h | 56 + spqlios/lib/test/testlib/test_commons.cpp | 10 + spqlios/lib/test/testlib/test_commons.h | 74 + spqlios/lib/test/testlib/test_hash.cpp | 24 + spqlios/lib/test/testlib/vec_rnx_layout.cpp | 182 ++ spqlios/lib/test/testlib/vec_rnx_layout.h | 85 + spqlios/lib/test/testlib/zn_layouts.cpp | 55 + spqlios/lib/test/testlib/zn_layouts.h | 29 + spqlios/src/lib.rs | 15 + spqlios/src/mod.rs | 1 + spqlios/src/module.rs | 91 + spqlios/src/poly.rs | 190 ++ spqlios/tests/module.rs | 9 + 201 files changed, 30406 insertions(+), 3 deletions(-) create mode 100644 spqlios/Cargo.toml create mode 100644 spqlios/build.rs create mode 100644 spqlios/examples/fft.rs create mode 100644 spqlios/lib/.clang-format create mode 100644 spqlios/lib/.gitignore create mode 100644 spqlios/lib/CMakeLists.txt create mode 100644 spqlios/lib/CONTRIBUTING.md create mode 100644 spqlios/lib/Changelog.md create mode 100644 spqlios/lib/LICENSE create mode 100644 spqlios/lib/README.md create mode 100644 spqlios/lib/docs/api-full.svg create mode 100644 spqlios/lib/docs/logo-inpher1.png create mode 100644 
spqlios/lib/docs/logo-inpher2.png create mode 100644 spqlios/lib/docs/logo-sandboxaq-black.svg create mode 100644 spqlios/lib/docs/logo-sandboxaq-white.svg create mode 100644 spqlios/lib/manifest.yaml create mode 100644 spqlios/lib/scripts/auto-release.sh create mode 100644 spqlios/lib/scripts/ci-pkg create mode 100644 spqlios/lib/scripts/prepare-release create mode 100644 spqlios/lib/spqlios/CMakeLists.txt create mode 100644 spqlios/lib/spqlios/arithmetic/module_api.c create mode 100644 spqlios/lib/spqlios/arithmetic/scalar_vector_product.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_api.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_big.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_dft.c create mode 100644 spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c create mode 100644 spqlios/lib/spqlios/arithmetic/vector_matrix_product.c create mode 100644 spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_api.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_arithmetic.h create mode 100644 spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h create mode 100644 spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h create mode 100644 spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c create mode 100644 spqlios/lib/spqlios/arithmetic/znx_small.c create mode 100644 spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c create mode 100644 spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h create mode 100644 spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c create mode 100644 spqlios/lib/spqlios/commons.c create mode 100644 spqlios/lib/spqlios/commons.h create mode 100644 spqlios/lib/spqlios/commons_private.c create mode 100644 spqlios/lib/spqlios/commons_private.h create mode 100644 spqlios/lib/spqlios/cplx/README.md create mode 100644 spqlios/lib/spqlios/cplx/cplx_common.c create mode 100644 
spqlios/lib/spqlios/cplx/cplx_conversions.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_execute.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft.h create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_asserts.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_avx512.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_internal.h create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_private.h create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_ref.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fft_sse.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s create mode 100644 spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s create mode 100644 spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c create mode 100644 spqlios/lib/spqlios/cplx/cplx_ifft_ref.c create mode 100644 spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c create mode 100644 spqlios/lib/spqlios/ext/neon_accel/macrof.h create mode 100644 spqlios/lib/spqlios/ext/neon_accel/macrofx4.h create mode 100644 spqlios/lib/spqlios/q120/q120_arithmetic.h create mode 100644 spqlios/lib/spqlios/q120/q120_arithmetic_avx2.c create mode 100644 spqlios/lib/spqlios/q120/q120_arithmetic_private.h create mode 100644 spqlios/lib/spqlios/q120/q120_arithmetic_ref.c create mode 100644 spqlios/lib/spqlios/q120/q120_arithmetic_simple.c create mode 100644 spqlios/lib/spqlios/q120/q120_common.h create mode 100644 spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c create mode 100644 spqlios/lib/spqlios/q120/q120_ntt.c create mode 100644 spqlios/lib/spqlios/q120/q120_ntt.h create mode 100644 spqlios/lib/spqlios/q120/q120_ntt_avx2.c create mode 100644 spqlios/lib/spqlios/q120/q120_ntt_private.h create mode 100644 spqlios/lib/spqlios/reim/reim_conversions.c create mode 100644 spqlios/lib/spqlios/reim/reim_conversions_avx.c create mode 100644 spqlios/lib/spqlios/reim/reim_execute.c create mode 100644 spqlios/lib/spqlios/reim/reim_fallbacks_aarch64.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft.h create mode 100644 spqlios/lib/spqlios/reim/reim_fft16_avx_fma.s create mode 100644 spqlios/lib/spqlios/reim/reim_fft16_avx_fma_win32.s create mode 100644 spqlios/lib/spqlios/reim/reim_fft4_avx_fma.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft8_avx_fma.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft_avx2.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft_core_template.h create mode 100644 spqlios/lib/spqlios/reim/reim_fft_ifft.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft_internal.h create mode 100644 spqlios/lib/spqlios/reim/reim_fft_neon.c create mode 100644 spqlios/lib/spqlios/reim/reim_fft_private.h create mode 100644 spqlios/lib/spqlios/reim/reim_fft_ref.c create mode 100644 spqlios/lib/spqlios/reim/reim_fftvec_addmul_fma.c create mode 100644 spqlios/lib/spqlios/reim/reim_fftvec_addmul_ref.c create mode 100644 spqlios/lib/spqlios/reim/reim_ifft16_avx_fma.s create mode 100644 spqlios/lib/spqlios/reim/reim_ifft16_avx_fma_win32.s create mode 100644 spqlios/lib/spqlios/reim/reim_ifft4_avx_fma.c create mode 100644 
spqlios/lib/spqlios/reim/reim_ifft8_avx_fma.c create mode 100644 spqlios/lib/spqlios/reim/reim_ifft_avx2.c create mode 100644 spqlios/lib/spqlios/reim/reim_ifft_ref.c create mode 100644 spqlios/lib/spqlios/reim/reim_to_tnx_avx.c create mode 100644 spqlios/lib/spqlios/reim/reim_to_tnx_ref.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_arithmetic.h create mode 100644 spqlios/lib/spqlios/reim4/reim4_arithmetic_avx2.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_arithmetic_ref.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_execute.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fallbacks_aarch64.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_fma.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_ref.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_conv_fma.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_conv_ref.c create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_internal.h create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_private.h create mode 100644 spqlios/lib/spqlios/reim4/reim4_fftvec_public.h create mode 100644 spqlios/lib/test/CMakeLists.txt create mode 100644 spqlios/lib/test/spqlios_coeffs_arithmetic_test.cpp create mode 100644 spqlios/lib/test/spqlios_cplx_conversions_test.cpp create mode 100644 spqlios/lib/test/spqlios_cplx_fft_bench.cpp create mode 100644 spqlios/lib/test/spqlios_cplx_test.cpp create mode 100644 spqlios/lib/test/spqlios_q120_arithmetic_bench.cpp create mode 100644 spqlios/lib/test/spqlios_q120_arithmetic_test.cpp create mode 100644 spqlios/lib/test/spqlios_q120_ntt_bench.cpp create mode 100644 spqlios/lib/test/spqlios_q120_ntt_test.cpp create mode 100644 spqlios/lib/test/spqlios_reim4_arithmetic_bench.cpp create mode 100644 spqlios/lib/test/spqlios_reim4_arithmetic_test.cpp create mode 100644 spqlios/lib/test/spqlios_reim_conversions_test.cpp create mode 100644 spqlios/lib/test/spqlios_reim_test.cpp create mode 100644 spqlios/lib/test/spqlios_svp_product_test.cpp create mode 100644 spqlios/lib/test/spqlios_svp_test.cpp create mode 100644 spqlios/lib/test/spqlios_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_rnx_conversions_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_rnx_ppol_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_rnx_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_rnx_vmp_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_znx_big_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_znx_dft_test.cpp create mode 100644 spqlios/lib/test/spqlios_vec_znx_test.cpp create mode 100644 spqlios/lib/test/spqlios_vmp_product_test.cpp create mode 100644 spqlios/lib/test/spqlios_zn_approxdecomp_test.cpp create mode 100644 spqlios/lib/test/spqlios_zn_conversions_test.cpp create mode 100644 spqlios/lib/test/spqlios_zn_vmp_test.cpp create mode 100644 spqlios/lib/test/spqlios_znx_small_test.cpp create mode 100644 spqlios/lib/test/testlib/fft64_dft.cpp create mode 100644 spqlios/lib/test/testlib/fft64_dft.h create mode 100644 spqlios/lib/test/testlib/fft64_layouts.cpp create mode 100644 spqlios/lib/test/testlib/fft64_layouts.h create mode 100644 spqlios/lib/test/testlib/mod_q120.cpp create mode 100644 spqlios/lib/test/testlib/mod_q120.h create mode 100644 spqlios/lib/test/testlib/negacyclic_polynomial.cpp create mode 100644 spqlios/lib/test/testlib/negacyclic_polynomial.h create mode 100644 spqlios/lib/test/testlib/negacyclic_polynomial_impl.h 
create mode 100644 spqlios/lib/test/testlib/ntt120_dft.cpp create mode 100644 spqlios/lib/test/testlib/ntt120_dft.h create mode 100644 spqlios/lib/test/testlib/ntt120_layouts.cpp create mode 100644 spqlios/lib/test/testlib/ntt120_layouts.h create mode 100644 spqlios/lib/test/testlib/polynomial_vector.cpp create mode 100644 spqlios/lib/test/testlib/polynomial_vector.h create mode 100644 spqlios/lib/test/testlib/random.cpp create mode 100644 spqlios/lib/test/testlib/reim4_elem.cpp create mode 100644 spqlios/lib/test/testlib/reim4_elem.h create mode 100644 spqlios/lib/test/testlib/sha3.c create mode 100644 spqlios/lib/test/testlib/sha3.h create mode 100644 spqlios/lib/test/testlib/test_commons.cpp create mode 100644 spqlios/lib/test/testlib/test_commons.h create mode 100644 spqlios/lib/test/testlib/test_hash.cpp create mode 100644 spqlios/lib/test/testlib/vec_rnx_layout.cpp create mode 100644 spqlios/lib/test/testlib/vec_rnx_layout.h create mode 100644 spqlios/lib/test/testlib/zn_layouts.cpp create mode 100644 spqlios/lib/test/testlib/zn_layouts.h create mode 100644 spqlios/src/lib.rs create mode 100644 spqlios/src/mod.rs create mode 100644 spqlios/src/module.rs create mode 100644 spqlios/src/poly.rs create mode 100644 spqlios/tests/module.rs diff --git a/Cargo.lock b/Cargo.lock index 1f71ab3..659ed82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,32 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools 0.10.5", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + +[[package]] +name = "bitflags" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" + [[package]] name = "bumpalo" version = "3.16.0" @@ -67,6 +93,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -100,6 +135,17 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.23" @@ -215,6 +261,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "half" version = "2.4.1" @@ -288,6 +340,16 @@ version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +[[package]] +name = "libloading" +version = "0.8.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets", +] + [[package]] name = "libm" version = "0.2.11" @@ -334,6 +396,12 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "ndarray" version = "0.16.1" @@ -349,6 +417,16 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num" version = "0.4.3" @@ -507,6 +585,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "primality-test" version = "0.3.0" @@ -637,6 +725,12 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "rustc-hash" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" + [[package]] name = "ryu" version = "1.0.18" @@ -692,12 +786,27 @@ dependencies = [ "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "spqlios" +version = "0.1.0" +dependencies = [ + "bindgen", + "itertools 0.14.0", + "sampling", +] + [[package]] name = "sprs" version = "0.11.2" @@ -715,9 +824,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c13f7e4..ca969ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["math", "sampling", "utils"] +members = ["math", "sampling", "spqlios", "utils"] diff --git a/spqlios/Cargo.toml b/spqlios/Cargo.toml new file mode 100644 index 0000000..16048d1 --- /dev/null +++ b/spqlios/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "spqlios" +version = "0.1.0" +edition = "2021" + +[dependencies] +itertools = "0.14.0" +sampling = { path = "../sampling" } + +[build-dependencies] +bindgen = "0.71.1" diff --git a/spqlios/build.rs b/spqlios/build.rs new file mode 100644 index 0000000..9b57c59 --- /dev/null +++ b/spqlios/build.rs @@ -0,0 +1,52 @@ +use bindgen; +use std::env; +use std::fs; +use std::path::absolute; +use 
std::path::PathBuf;
+use std::time::SystemTime;
+
+fn main() {
+    // Paths to the C header files
+    let header_paths = [
+        "lib/spqlios/coeffs/coeffs_arithmetic.h",
+        "lib/spqlios/arithmetic/vec_znx_arithmetic.h",
+    ];
+
+    let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
+    let bindings_file = out_path.join("bindings.rs");
+
+    let regenerate: bool = header_paths.iter().any(|header| {
+        let header_metadata: SystemTime = fs::metadata(header)
+            .and_then(|m| m.modified())
+            .unwrap_or(SystemTime::UNIX_EPOCH);
+        let bindings_metadata: SystemTime = fs::metadata(&bindings_file)
+            .and_then(|m| m.modified())
+            .unwrap_or(SystemTime::UNIX_EPOCH);
+        header_metadata > bindings_metadata
+    });
+
+    if regenerate {
+        // Generate the Rust bindings
+        let mut builder: bindgen::Builder = bindgen::Builder::default();
+        for header in header_paths {
+            builder = builder.header(header);
+        }
+
+        let bindings = builder
+            .generate_comments(false) // Optional: do not copy C comments into the bindings
+            .generate_inline_functions(true) // Optional: includes inline functions
+            .generate()
+            .expect("Unable to generate bindings");
+
+        // Write the bindings to the OUT_DIR
+        bindings
+            .write_to_file(&bindings_file)
+            .expect("Couldn't write bindings!");
+    }
+
+    println!(
+        "cargo:rustc-link-search=native={}",
+        absolute("./lib/build/spqlios").unwrap().to_str().unwrap()
+    );
+    println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios"
+}
diff --git a/spqlios/examples/fft.rs b/spqlios/examples/fft.rs
new file mode 100644
index 0000000..50f5e5d
--- /dev/null
+++ b/spqlios/examples/fft.rs
@@ -0,0 +1,57 @@
+use std::ffi::c_void;
+use std::time::Instant;
+
+use spqlios::bindings::*;
+
+fn main() {
+    let log_bound: usize = 19;
+
+    let n: usize = 2048;
+    let m: usize = n >> 1;
+
+    let mut a: Vec<i64> = vec![i64::default(); n];
+    let mut b: Vec<i64> = vec![i64::default(); n];
+    let mut c: Vec<i64> = vec![i64::default(); n];
+
+    a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
+    b[1] = 1;
+
+    println!("{:?}", b);
+
+    unsafe {
+        let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
+        let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
+
+        let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
+        let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
+        let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
+
+        let now = Instant::now();
+        (0..1024).for_each(|_| {
+            reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
+            reim_fft(reim_fft_precomp, buf_a);
+
+            reim_from_znx64_simple(m as u32, log_bound as u32, buf_b as *mut c_void, b.as_ptr());
+            reim_fft(reim_fft_precomp, buf_b);
+
+            reim_fftvec_mul_simple(
+                m as u32,
+                buf_c as *mut c_void,
+                buf_a as *mut c_void,
+                buf_b as *mut c_void,
+            );
+            reim_ifft(reim_ifft_precomp, buf_c);
+
+            reim_to_znx64_simple(
+                m as u32,
+                m as f64,
+                log_bound as u32,
+                c.as_mut_ptr(),
+                buf_c as *mut c_void,
+            )
+        });
+
+        println!("time: {}us", now.elapsed().as_micros());
+        println!("{:?}", &c[..16]);
+    }
+}
diff --git a/spqlios/lib/.clang-format b/spqlios/lib/.clang-format
new file mode 100644
index 0000000..120c0ac
--- /dev/null
+++ b/spqlios/lib/.clang-format
@@ -0,0 +1,14 @@
+# Use the Google style in this project.
+BasedOnStyle: Google
+
+# Some folks prefer to write "int& foo" while others prefer "int &foo".
The +# Google Style Guide only asks for consistency within a project, we chose +# "int& foo" for this project: +DerivePointerAlignment: false +PointerAlignment: Left + +# The Google Style Guide only asks for consistency w.r.t. "east const" vs. +# "const west" alignment of cv-qualifiers. In this project we use "east const". +QualifierAlignment: Left + +ColumnLimit: 120 diff --git a/spqlios/lib/.gitignore b/spqlios/lib/.gitignore new file mode 100644 index 0000000..7b81a31 --- /dev/null +++ b/spqlios/lib/.gitignore @@ -0,0 +1,4 @@ +cmake-build-* +.idea + +build/ diff --git a/spqlios/lib/CMakeLists.txt b/spqlios/lib/CMakeLists.txt new file mode 100644 index 0000000..5711d8e --- /dev/null +++ b/spqlios/lib/CMakeLists.txt @@ -0,0 +1,69 @@ +cmake_minimum_required(VERSION 3.8) +project(spqlios) + +# read the current version from the manifest file +file(READ "manifest.yaml" manifest) +string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest}) +#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}") +set(SPQLIOS_VERSION ${CMAKE_MATCH_1}) +set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2}) +set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3}) +set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4}) +message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}") + +#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath") +set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors") +set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests") +set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)") + +if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE) +endif() +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +if (WARNING_PARANOID) + add_compile_options(-Wall -Werror -Wno-unused-command-line-argument) +endif() + +message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") + set(X86 ON) + set(AARCH64 OFF) +else () + set(X86 OFF) + # set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets +endif () +if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)") + set(AARCH64 ON) +endif () + +if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)") + set(WIN32 ON) +endif () +if (WIN32) + #overrides for win32 + set(X86 OFF) + set(AARCH64 OFF) + set(X86_WIN32 ON) +else() + set(X86_WIN32 OFF) + set(WIN32 OFF) +endif (WIN32) + +message(STATUS "--> WIN32: ${WIN32}") +message(STATUS "--> X86_WIN32: ${X86_WIN32}") +message(STATUS "--> X86_LINUX: ${X86}") +message(STATUS "--> AARCH64: ${AARCH64}") + +# compiles the main library in spqlios +add_subdirectory(spqlios) + +# compiles and activates unittests and itests +if (${ENABLE_TESTING}) + enable_testing() + add_subdirectory(test) +endif() + diff --git a/spqlios/lib/CONTRIBUTING.md b/spqlios/lib/CONTRIBUTING.md new file mode 100644 index 0000000..a30c304 --- /dev/null +++ b/spqlios/lib/CONTRIBUTING.md @@ -0,0 +1,77 @@ +# Contributing to SPQlios-fft + +The spqlios-fft team encourages contributions. +We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features. +We encourage researchers to contribute with implementations of their FFT or NTT algorithms. 
+In the following we are trying to give some guidance on how to contribute effectively.
+
+## Communication ##
+
+Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).
+
+All communications are public, so please make sure to maintain professional behaviour in
+all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
+guidelines.
+
+## Reporting Bugs or Requesting Features ##
+
+Bugs should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).
+
+Features can also be requested there; in this case, please ensure that the features you request are self-contained,
+easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
+possible.
+
+## Setting up topic branches and generating pull requests
+
+This section applies to people who already have write access to the repository. Specific instructions for pull-requests
+from public forks will be given later.
+
+To implement some changes, please follow these steps:
+
+- Create a "topic branch". Usually, the branch name should be `username/small-title`
+  or better `username/issuenumber-small-title`, where `issuenumber` is the number of
+  the GitHub issue being tackled.
+- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
+- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm`.
+- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.
+
+  > _During the process, you may optionally use `git rebase -i` to clean up your commit history. If you elect to do so,
+  please at the very least make sure that nobody else is working on or has forked from your branch: the conflicts it would generate
+  and the human hours to fix them are not worth it. `Git merge` remains the preferred option._
+
+- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the GitHub webpage.
+
+### Keep your pull requests limited to a single issue
+
+Pull requests should be as small/atomic as possible.
+
+### Coding Conventions
+
+* Please make sure that your code is formatted according to the `.clang-format` file and
+  that all files end with a newline character.
+* Please make sure that all the functions declared in the public API have relevant doxygen comments.
+  Preferably, functions in the private APIs should also contain a brief doxygen description.
+
+### Versions and History
+
+* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
+  the form `x.y.z`:
+  * a patch release that increments `z` does not modify the stable API.
+  * a minor release that increments `y` adds a new feature to the stable API.
+  * In the unlikely case where we need to change or remove a feature, we will trigger a major release that
+    increments `x`.
+
+  > _If this happens, we will mark the affected features as deprecated at least six months before the major release._
+
+* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
+  your own risk,
+  but keep in mind that semantic versioning does not apply to them.
+
+> _If you have a use-case that uses an experimental feature, we encourage
+> you to tell us about it, so that this feature reaches the stable section faster!_
+
+* **Version history** The current version is reported in `manifest.yaml`; any change of version comes with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
+  get insight about
+  the history of the repository (not the commit graph).
+
+> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenarios where multiple features are developed in parallel: it may therefore be non-linear. If, as humans, we would like to see a linear history, please read `Changelog.md` instead!_
diff --git a/spqlios/lib/Changelog.md b/spqlios/lib/Changelog.md
new file mode 100644
index 0000000..5c0d2ea
--- /dev/null
+++ b/spqlios/lib/Changelog.md
@@ -0,0 +1,18 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [2.0.0] - 2024-08-21
+
+- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` APIs.
+- Hardware acceleration available: AVX2 (most parts)
+- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will be promoted to "stable mode".
+
+## [1.0.0] - 2023-07-18
+
+- Initial release of the double-precision FFT on the reim and cplx backends
+- Coeffs-space conversions cplx <-> znx32 and tnx32
+- FFT-space conversions cplx <-> reim4 layouts
+- FFT-space multiplications on the cplx, reim and reim4 layouts.
+- In this first release, the only platform supported is Linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
diff --git a/spqlios/lib/LICENSE b/spqlios/lib/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/spqlios/lib/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/spqlios/lib/README.md b/spqlios/lib/README.md new file mode 100644 index 0000000..9edc19d --- /dev/null +++ b/spqlios/lib/README.md @@ -0,0 +1,65 @@ +# SPQlios library + + + +The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography. + + + +Namely, it is divided into 4 sections: + +* The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space. +* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. 
It provides, in particular, efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products: operations that naturally occur on gadget-decomposed Ring-LWE coordinates.
+* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly well adapted to reproducing the fastest CGGI-based bootstrappings.
+* The ZN section focuses on vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also by non-ring schemes like Frodo, FrodoPIR, and SimplePIR).
+
+### A high-value target for hardware acceleration
+
+SPQlios is more than a library: it is also a good target for hardware developers.
+On the one hand, the arithmetic operations defined in the library have a clear standalone mathematical definition. On the other hand, the amount of work in each operation is large enough that meaningful functions only require a few of them.
+
+This makes the SPQlios API a high-value target for FHE-oriented hardware acceleration.
+
+### SPQlios is not an FHE library, but a huge enabler
+
+SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions it exposes, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.
+
+
+## Dependencies
+
+The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
+can be used from regular C code, and from any other language via classical foreign-function interfaces.
+
+The unit tests and integration tests are in an optional part of the code, and are written in C++. These tests rely on the
+[```benchmark```](https://github.com/google/benchmark) and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.
+
+Currently, the project has been tested with gcc/g++ >= 11.3.0 under Linux (x86_64). In the future, we plan to
+extend the compatibility to other compilers, platforms and operating systems.
+
+
+## Installation
+
+The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top-level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, as well as an up-to-date C++ compiler (e.g. g++ >= 11.3.0).
+
+It will compile the shared library in optimized mode, and ```make install``` installs it to the desired prefix folder (by default ```/usr/local/lib```).
+
+If you want to choose additional compile options (e.g. another installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
+```
+mkdir build
+cd build
+cmake .. -DCMAKE_INSTALL_PREFIX=/usr/
+make
+```
+The available options are the following:
+
+| Variable Name        | values                                                       |
+| -------------------- | ------------------------------------------------------------ |
+| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) |
+| WARNING_PARANOID     | All warnings are shown and treated as errors. Off by default |
+| ENABLE_TESTING       | Compiles unit tests and integration tests                    |
+
+------
+
+
diff --git a/spqlios/lib/docs/api-full.svg b/spqlios/lib/docs/api-full.svg
new file mode 100644
index 0000000..8cc9743
--- /dev/null
+++ b/spqlios/lib/docs/api-full.svg
@@ -0,0 +1,416 @@
[api-full.svg: architecture diagram showing the SPQlios API (FFT/NTT, VEC_ZNX, RNX and ZN sections) and potential use-cases that would benefit from it: FHE libraries (BGV, BFV, CKKS and TFHE/CGGI, DM), other PQC (NTRU, Falcon), Onion-PIRs (Onion, Spiral), Chimera scheme switching, and LWE/SIS schemes (Frodo, FrodoPIR); SVG markup omitted.]
diff --git a/spqlios/lib/docs/logo-inpher1.png b/spqlios/lib/docs/logo-inpher1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce8de011e6373e4e600d573d5748156cbada9f21
GIT binary patch
literal 24390
[base85-encoded binary image data omitted]
diff --git a/spqlios/lib/docs/logo-inpher1.png b/spqlios/lib/docs/logo-inpher1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce8de011e6373e4e600d573d5748156cbada9f21
GIT binary patch
literal 24390
(binary PNG data omitted)

literal 0
HcmV?d00001

diff --git a/spqlios/lib/docs/logo-inpher2.png b/spqlios/lib/docs/logo-inpher2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a25c87f16dd1afc71658586772db657a71c7e337
GIT binary patch
literal 24239
(binary PNG data omitted)
(SVG logo, markup stripped; visible text: SANDBOX AQ)
diff --git a/spqlios/lib/docs/logo-sandboxaq-white.svg b/spqlios/lib/docs/logo-sandboxaq-white.svg
new file mode 100644
index 0000000..036ce5a
--- /dev/null
+++ b/spqlios/lib/docs/logo-sandboxaq-white.svg
@@ -0,0 +1,133 @@
+(SVG logo, markup stripped; visible text: SANDBOX AQ)
diff --git a/spqlios/lib/manifest.yaml b/spqlios/lib/manifest.yaml
new file mode 100644
index 0000000..02235cb
--- /dev/null
+++ b/spqlios/lib/manifest.yaml
@@ -0,0 +1,2 @@
+library: spqlios-fft
+version: 2.0.0
diff --git a/spqlios/lib/scripts/auto-release.sh b/spqlios/lib/scripts/auto-release.sh
new file mode 100644
index 0000000..c31efbe
--- /dev/null
+++ b/spqlios/lib/scripts/auto-release.sh
@@ -0,0 +1,27 @@
+#!/bin/sh
+
+# this script generates one tag if there is a version change in manifest.yaml
+cd `dirname $0`/..
+if [ "v$1" = "v-y" ]; then
+  echo "production mode!";
+fi
+changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
+oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
+version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
+echo "Versions: $oldversion --> $version"
+if [ "v$oldversion" = "v$version" ]; then
+  echo "Same version - nothing to do"; exit 0;
+fi
+if [ "v$1" = "v-y" ]; then
+  git config user.name github-actions
+  git config user.email github-actions@github.com
+  git tag -a "v$version" -m "Version $version"
+  git push origin "v$version"
+else
+cat </dev/null
+  rm -f "$DIR/$FNAME" 2>/dev/null
+  DESTDIR="$DIR/dist" cmake --install build || exit 1
+  if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
+    tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
+  else
+    # fix since msys can mess up the paths
+    REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
+    echo "REAL_DEST: $REAL_DEST"
+    [ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
+  fi
+  [ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
+  [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
+  curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
+fi
+
+if [ "x$1" = "xinstall" ]; then
+  [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
+  # cleaning
+  rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
+  rm -f "$DIR/$FNAME" 2>/dev/null
+  # downloading
+  curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
+  [ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
+  # installing
+  mkdir -p $DESTDIR
+  tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
+  exit 0
+fi
diff --git a/spqlios/lib/scripts/prepare-release b/spqlios/lib/scripts/prepare-release
new file mode 100644
index 0000000..4e4843a
--- /dev/null
+++ b/spqlios/lib/scripts/prepare-release
@@ -0,0 +1,181 @@
+#!/usr/bin/perl
+##
+## This script will help update manifest.yaml and Changelog.md before a release
+## Any merge to master that changes the version line in manifest.yaml
+## is considered as a new release.
+##
+## When ready to make a release, please run ./scripts/prepare-release
+## and commit and push the final result!
+use File::Basename;
+use Cwd 'abs_path';
+
+# find its way to the root of git's repository
+my $scriptsdirname = dirname(abs_path(__FILE__));
+chdir "$scriptsdirname/..";
+print "✓ Entering directory:".`pwd`;
+
+# ensures that the current branch is ahead of origin/main
+my $diff= `git diff`;
+chop $diff;
+if ($diff =~ /./) {
+  die("ERROR: Please commit all the changes before calling the prepare-release script.");
+} else {
+  print("✓ All changes are committed.\n");
+}
+system("git fetch origin");
+my $vcount = `git rev-list --left-right --count origin/main...HEAD`;
+$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/;
+if ($2>0) {
+  die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main.");
+} else {
+  print("✓ Current HEAD is up to date with origin/main.\n");
+}
+
+mkdir ".changes";
+my $currentbranch = `git rev-parse --abbrev-ref HEAD`;
+chop $currentbranch;
+$currentbranch =~ s/[^a-zA-Z._-]+/-/g;
+my $changefile=".changes/$currentbranch.md";
+my $origmanifestfile=".changes/$currentbranch--manifest.yaml";
+my $origchangelogfile=".changes/$currentbranch--Changelog.md";
+
+my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml");
+if ($exit_code!=0 or ! -f $origmanifestfile) {
+  die("ERROR: failed to download manifest.yaml");
+}
+$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md");
+if ($exit_code!=0 or ! -f $origchangelogfile) {
+  die("ERROR: failed to download Changelog.md");
+}
+
+# read the current version (from origin/main manifest)
+my $vmajor = 0;
+my $vminor = 0;
+my $vpatch = 0;
+my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`;
+chop $versionline;
+if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) {
+  die("ERROR: invalid version in manifest file: $versionline\n");
+} else {
+  $vmajor = int($1);
+  $vminor = int($2);
+  $vpatch = int($3);
+}
+print "Version in manifest file: $vmajor.$vminor.$vpatch\n";
+
+if (not -f $changefile) {
+  ## create a changes file
+  open F,">$changefile";
+  print F "# Changefile for branch $currentbranch\n\n";
+  print F "## Type of release (major,minor,patch)?\n\n";
+  print F "releasetype: patch\n\n";
+  print F "## What has changed (please edit)?\n\n";
+  print F "- This has changed.\n";
+  close F;
+}
+
+system("editor $changefile");
+
+# compute the new version
+my $nvmajor;
+my $nvminor;
+my $nvpatch;
+my $changelog;
+my $recordchangelog=0;
+open F,"$changefile";
+while ($line=<F>) {
+  chop $line;
+  if ($recordchangelog) {
+    ($line =~ /^$/) and next;
+    $changelog .= "$line\n";
+    next;
+  }
+  if ($line =~ /^releasetype *: *patch *$/) {
+    $nvmajor=$vmajor;
+    $nvminor=$vminor;
+    $nvpatch=$vpatch+1;
+  }
+  if ($line =~ /^releasetype *: *minor *$/) {
+    $nvmajor=$vmajor;
+    $nvminor=$vminor+1;
+    $nvpatch=0;
+  }
+  if ($line =~ /^releasetype *: *major *$/) {
+    $nvmajor=$vmajor+1;
+    $nvminor=0;
+    $nvpatch=0;
+  }
+  if ($line =~ /^## What has changed/) {
+    $recordchangelog=1;
+  }
+}
+close F;
+print "New version: $nvmajor.$nvminor.$nvpatch\n";
+print "Changes:\n$changelog";
+
+# updating manifest.yaml
+open F,"manifest.yaml";
+open G,">.changes/manifest.yaml";
+while ($line=<F>) {
+  if ($line =~ /^version *: */) {
+    print G "version: $nvmajor.$nvminor.$nvpatch\n";
+    next;
+  }
+  print G $line;
+}
+close F;
+close G;
+# updating Changelog.md
+open F,"$origchangelogfile";
+open G,">.changes/Changelog.md";
+print G <<EOF;
+## [$nvmajor.$nvminor.$nvpatch]
+
+$changelog
+EOF
+while ($line=<F>) {
+  if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) {
+    if ($1>$nvmajor) {
+ die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($1<$nvmajor) { + $skip_section=0; + } elsif ($2>$nvminor) { + die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($2<$nvminor) { + $skip_section=0; + } elsif ($3>$nvpatch) { + die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($2<$nvpatch) { + $skip_section=0; + } else { + $skip_section=1; + } + } + ($skip_section) and next; + print G $line; +} +close F; +close G; + +print "-------------------------------------\n"; +print "THIS WILL BE UPDATED:\n"; +print "-------------------------------------\n"; +system("diff -u manifest.yaml .changes/manifest.yaml"); +system("diff -u Changelog.md .changes/Changelog.md"); +print "-------------------------------------\n"; +print "To proceed: press otherwise \n"; +my $bla; +$bla=; +system("cp -vf .changes/manifest.yaml manifest.yaml"); +system("cp -vf .changes/Changelog.md Changelog.md"); +system("git commit -a -m \"Update version and changelog.\""); +system("git push"); +print("✓ Changes have been committed and pushed!\n"); +print("✓ A new release will be created when this branch is merged to main.\n"); + diff --git a/spqlios/lib/spqlios/CMakeLists.txt b/spqlios/lib/spqlios/CMakeLists.txt new file mode 100644 index 0000000..4326576 --- /dev/null +++ b/spqlios/lib/spqlios/CMakeLists.txt @@ -0,0 +1,223 @@ +enable_language(ASM) + +# C source files that are compiled for all targets (i.e. reference code) +set(SRCS_GENERIC + commons.c + commons_private.c + coeffs/coeffs_arithmetic.c + arithmetic/vec_znx.c + arithmetic/vec_znx_dft.c + arithmetic/vector_matrix_product.c + cplx/cplx_common.c + cplx/cplx_conversions.c + cplx/cplx_fft_asserts.c + cplx/cplx_fft_ref.c + cplx/cplx_fftvec_ref.c + cplx/cplx_ifft_ref.c + cplx/spqlios_cplx_fft.c + reim4/reim4_arithmetic_ref.c + reim4/reim4_fftvec_addmul_ref.c + reim4/reim4_fftvec_conv_ref.c + reim/reim_conversions.c + reim/reim_fft_ifft.c + reim/reim_fft_ref.c + reim/reim_fftvec_addmul_ref.c + reim/reim_ifft_ref.c + reim/reim_ifft_ref.c + reim/reim_to_tnx_ref.c + q120/q120_ntt.c + q120/q120_arithmetic_ref.c + q120/q120_arithmetic_simple.c + arithmetic/scalar_vector_product.c + arithmetic/vec_znx_big.c + arithmetic/znx_small.c + arithmetic/module_api.c + arithmetic/zn_vmp_int8_ref.c + arithmetic/zn_vmp_int16_ref.c + arithmetic/zn_vmp_int32_ref.c + arithmetic/zn_vmp_ref.c + arithmetic/zn_api.c + arithmetic/zn_conversions_ref.c + arithmetic/zn_approxdecomp_ref.c + arithmetic/vec_rnx_api.c + arithmetic/vec_rnx_conversions_ref.c + arithmetic/vec_rnx_svp_ref.c + reim/reim_execute.c + cplx/cplx_execute.c + reim4/reim4_execute.c + arithmetic/vec_rnx_arithmetic.c + arithmetic/vec_rnx_approxdecomp_ref.c + arithmetic/vec_rnx_vmp_ref.c +) +# C or assembly source files compiled only on x86 targets +set(SRCS_X86 + ) +# C or assembly source files compiled only on aarch64 targets +set(SRCS_AARCH64 + cplx/cplx_fallbacks_aarch64.c + reim/reim_fallbacks_aarch64.c + reim4/reim4_fallbacks_aarch64.c + q120/q120_fallbacks_aarch64.c + reim/reim_fft_neon.c +) + +# C or assembly source files compiled only on x86: avx, avx2, fma targets +set(SRCS_FMA_C + arithmetic/vector_matrix_product_avx.c + cplx/cplx_conversions_avx2_fma.c + cplx/cplx_fft_avx2_fma.c + cplx/cplx_fft_sse.c + cplx/cplx_fftvec_avx2_fma.c + cplx/cplx_ifft_avx2_fma.c + reim4/reim4_arithmetic_avx2.c + reim4/reim4_fftvec_conv_fma.c + reim4/reim4_fftvec_addmul_fma.c + reim/reim_conversions_avx.c + reim/reim_fft4_avx_fma.c + reim/reim_fft8_avx_fma.c + 
+    reim/reim_ifft4_avx_fma.c
+    reim/reim_ifft8_avx_fma.c
+    reim/reim_fft_avx2.c
+    reim/reim_ifft_avx2.c
+    reim/reim_to_tnx_avx.c
+    reim/reim_fftvec_addmul_fma.c
+)
+set(SRCS_FMA_ASM
+    cplx/cplx_fft16_avx_fma.s
+    cplx/cplx_ifft16_avx_fma.s
+    reim/reim_fft16_avx_fma.s
+    reim/reim_ifft16_avx_fma.s
+)
+set(SRCS_FMA_WIN32_ASM
+    cplx/cplx_fft16_avx_fma_win32.s
+    cplx/cplx_ifft16_avx_fma_win32.s
+    reim/reim_fft16_avx_fma_win32.s
+    reim/reim_ifft16_avx_fma_win32.s
+)
+set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
+set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
+
+# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets
+set(SRCS_AVX512
+    cplx/cplx_fft_avx512.c
+    )
+set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq")
+
+# C or assembly source files compiled only on x86: avx2 + bmi targets
+set(SRCS_AVX2
+    arithmetic/vec_znx_avx.c
+    coeffs/coeffs_arithmetic_avx.c
+    arithmetic/vec_znx_dft_avx2.c
+    arithmetic/zn_vmp_int8_avx.c
+    arithmetic/zn_vmp_int16_avx.c
+    arithmetic/zn_vmp_int32_avx.c
+    q120/q120_arithmetic_avx2.c
+    q120/q120_ntt_avx2.c
+    arithmetic/vec_rnx_arithmetic_avx.c
+    arithmetic/vec_rnx_approxdecomp_avx.c
+    arithmetic/vec_rnx_vmp_avx.c
+
+)
+set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")
+
+# C source files on float128 via libquadmath on x86 targets
+set(SRCS_F128
+    cplx_f128/cplx_fft_f128.c
+    cplx_f128/cplx_fft_f128.h
+    )
+
+# H header files containing the public API (these headers are installed)
+set(HEADERSPUBLIC
+    commons.h
+    arithmetic/vec_znx_arithmetic.h
+    arithmetic/vec_rnx_arithmetic.h
+    arithmetic/zn_arithmetic.h
+    cplx/cplx_fft.h
+    reim/reim_fft.h
+    q120/q120_common.h
+    q120/q120_arithmetic.h
+    q120/q120_ntt.h
+    )
+
+# H header files containing the private API (these headers are used internally)
+set(HEADERSPRIVATE
+    commons_private.h
+    cplx/cplx_fft_internal.h
+    cplx/cplx_fft_private.h
+    reim4/reim4_arithmetic.h
+    reim4/reim4_fftvec_internal.h
+    reim4/reim4_fftvec_private.h
+    reim4/reim4_fftvec_public.h
+    reim/reim_fft_internal.h
+    reim/reim_fft_private.h
+    q120/q120_arithmetic_private.h
+    q120/q120_ntt_private.h
+    arithmetic/vec_znx_arithmetic.h
+    arithmetic/vec_rnx_arithmetic_private.h
+    arithmetic/vec_rnx_arithmetic_plugin.h
+    arithmetic/zn_arithmetic_private.h
+    arithmetic/zn_arithmetic_plugin.h
+    coeffs/coeffs_arithmetic.h
+    reim/reim_fft_core_template.h
+)
+
+set(SPQLIOSSOURCES
+    ${SRCS_GENERIC}
+    ${HEADERSPUBLIC}
+    ${HEADERSPRIVATE}
+    )
+if (${X86})
+    set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
+        ${SRCS_X86}
+        ${SRCS_FMA_C}
+        ${SRCS_FMA_ASM}
+        ${SRCS_AVX2}
+        ${SRCS_AVX512}
+        )
+elseif (${X86_WIN32})
+    set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
+        #${SRCS_X86}
+        ${SRCS_FMA_C}
+        ${SRCS_FMA_WIN32_ASM}
+        ${SRCS_AVX2}
+        ${SRCS_AVX512}
+        )
+elseif (${AARCH64})
+    set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
+        ${SRCS_AARCH64}
+        )
+endif ()
+
+
+set(SPQLIOSLIBDEP
+    m # libm dependency for the cosine/sine functions
+    )
+
+if (ENABLE_SPQLIOS_F128)
+    find_library(quadmath REQUIRED NAMES quadmath)
+    set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128})
+    set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath)
+endif (ENABLE_SPQLIOS_F128)
+
+add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
+add_library(libspqlios SHARED ${SPQLIOSSOURCES})
+set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)
+set_property(TARGET 
libspqlios-static PROPERTY OUTPUT_NAME spqlios) +set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR}) +set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION}) +if (NOT APPLE) +target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined) +target_link_options(libspqlios PUBLIC -Wl,--no-undefined) +endif() +target_link_libraries(libspqlios ${SPQLIOSLIBDEP}) +target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP}) +install(TARGETS libspqlios-static) +install(TARGETS libspqlios) + +# install the public headers only +foreach (file ${HEADERSPUBLIC}) + get_filename_component(dir ${file} DIRECTORY) + install(FILES ${file} DESTINATION include/spqlios/${dir}) +endforeach () diff --git a/spqlios/lib/spqlios/arithmetic/module_api.c b/spqlios/lib/spqlios/arithmetic/module_api.c new file mode 100644 index 0000000..52140a0 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/module_api.c @@ -0,0 +1,164 @@ +#include + +#include "vec_znx_arithmetic_private.h" + +static void fill_generic_virtual_table(MODULE* module) { + // TODO add default ref handler here + module->func.vec_znx_zero = vec_znx_zero_ref; + module->func.vec_znx_copy = vec_znx_copy_ref; + module->func.vec_znx_negate = vec_znx_negate_ref; + module->func.vec_znx_add = vec_znx_add_ref; + module->func.vec_znx_sub = vec_znx_sub_ref; + module->func.vec_znx_rotate = vec_znx_rotate_ref; + module->func.vec_znx_automorphism = vec_znx_automorphism_ref; + module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref; + module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + module->func.vec_znx_negate = vec_znx_negate_avx; + module->func.vec_znx_add = vec_znx_add_avx; + module->func.vec_znx_sub = vec_znx_sub_avx; + } +} + +static void fill_fft64_virtual_table(MODULE* module) { + // TODO add default ref handler here + // module->func.vec_znx_dft = ...; + module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k; + module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes; + module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k; + module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes; + module->func.vec_znx_dft = fft64_vec_znx_dft; + module->func.vec_znx_idft = fft64_vec_znx_idft; + module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes; + module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a; + module->func.vec_znx_big_add = fft64_vec_znx_big_add; + module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small; + module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2; + module->func.vec_znx_big_sub = fft64_vec_znx_big_sub; + module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a; + module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b; + module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2; + module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate; + module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism; + module->func.svp_prepare = fft64_svp_prepare_ref; + module->func.svp_apply_dft = fft64_svp_apply_dft_ref; + module->func.znx_small_single_product = fft64_znx_small_single_product; + module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes; + 
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref; + module->func.vmp_prepare_contiguous_tmp_bytes = fft64_vmp_prepare_contiguous_tmp_bytes; + module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref; + module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes; + module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref; + module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes; + module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; + module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; + module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; + module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big; + module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol; + module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + // TODO: enable when avx implementation is done + module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx; + module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx; + module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx; + } +} + +static void fill_ntt120_virtual_table(MODULE* module) { + // TODO add default ref handler here + // module->func.vec_znx_dft = ...; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + module->func.vec_znx_dft = ntt120_vec_znx_dft_avx; + module->func.vec_znx_idft = ntt120_vec_znx_idft_avx; + module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx; + module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx; + } +} + +static void fill_virtual_table(MODULE* module) { + fill_generic_virtual_table(module); + switch (module->module_type) { + case FFT64: + fill_fft64_virtual_table(module); + break; + case NTT120: + fill_ntt120_virtual_table(module); + break; + default: + NOT_SUPPORTED(); // invalid type + } +} + +static void fill_fft64_precomp(MODULE* module) { + // fill any necessary precomp stuff + module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50); + module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0); + module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63); + module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0); + module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m); + module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m); +} +static void fill_ntt120_precomp(MODULE* module) { + // fill any necessary precomp stuff + if (CPU_SUPPORTS("avx2")) { + module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn); + module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn); + } +} + +static void fill_module_precomp(MODULE* module) { + switch (module->module_type) { + case FFT64: + fill_fft64_precomp(module); + break; + case NTT120: + fill_ntt120_precomp(module); + break; + default: + NOT_SUPPORTED(); // invalid type + } +} + +static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) { + // init to zero to ensure that any non-initialized field bug is detected + // by at least a "proper" segfault + memset(module, 0, sizeof(MODULE)); + module->module_type = mtype; + module->nn = nn; + module->m = nn >> 1; + fill_module_precomp(module); + fill_virtual_table(module); +} + +EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) { + MODULE* m = (MODULE*)malloc(sizeof(MODULE)); + fill_module(m, N, mtype); + return m; +} + +EXPORT void delete_module_info(MODULE* mod) { + switch (mod->module_type) { + case FFT64: + 
free(mod->mod.fft64.p_conv); + free(mod->mod.fft64.p_fft); + free(mod->mod.fft64.p_ifft); + free(mod->mod.fft64.p_reim_to_znx); + free(mod->mod.fft64.mul_fft); + free(mod->mod.fft64.p_addmul); + break; + case NTT120: + if (CPU_SUPPORTS("avx2")) { + q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt); + q120_del_intt_bb_precomp(mod->mod.q120.p_intt); + } + break; + default: + break; + } + free(mod); +} + +EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; } diff --git a/spqlios/lib/spqlios/arithmetic/scalar_vector_product.c b/spqlios/lib/spqlios/arithmetic/scalar_vector_product.c new file mode 100644 index 0000000..859893e --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/scalar_vector_product.c @@ -0,0 +1,63 @@ +#include + +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); } + +EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); } + +EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); } + +EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); } + +// public wrappers +EXPORT void svp_prepare(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +) { + module->func.svp_prepare(module, ppol, pol); +} + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +) { + reim_from_znx64(module->mod.fft64.p_conv, ppol, pol); + reim_fft(module->mod.fft64.p_fft, (double*)ppol); +} + +EXPORT void svp_apply_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl) { + module->func.svp_apply_dft(module, // N + res, + res_size, // output + ppol, // prepared pol + a, a_size, a_sl); +} + +// result = ppol * a +EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + double* const dres = (double*)res; + double* const dppol = (double*)ppol; + + const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const int64_t* a_ptr = a + i * a_sl; + double* const res_ptr = dres + i * nn; + // copy the polynomial to res, apply fft in place, call fftvec_mul in place. 
+ reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr); + reim_fft(module->mod.fft64.p_fft, res_ptr); + reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol); + } + + // then extend with zeros + memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double)); +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_api.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_api.c new file mode 100644 index 0000000..0f396fb --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_api.c @@ -0,0 +1,318 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +void fft64_init_rnx_module_precomp(MOD_RNX* module) { + // Add here initialization of items that are in the precomp + const uint64_t m = module->m; + module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0); + module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0); + module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m); + module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m); +} + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module) { + // Add here deleters for items that are in the precomp + delete_reim_fft_precomp(module->precomp.fft64.p_fft); + delete_reim_ifft_precomp(module->precomp.fft64.p_ifft); + delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul); + delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul); +} + +void fft64_init_rnx_module_vtable(MOD_RNX* module) { + // Add function pointers here + module->vtable.vec_rnx_add = vec_rnx_add_ref; + module->vtable.vec_rnx_zero = vec_rnx_zero_ref; + module->vtable.vec_rnx_copy = vec_rnx_copy_ref; + module->vtable.vec_rnx_negate = vec_rnx_negate_ref; + module->vtable.vec_rnx_sub = vec_rnx_sub_ref; + module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref; + module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref; + module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref; + module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref; + module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref; + module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref; + module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref; + module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref; + module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref; + module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref; + module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref; + module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol; + module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref; + module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.vec_rnx_add = vec_rnx_add_avx; + module->vtable.vec_rnx_sub = vec_rnx_sub_avx; + module->vtable.vec_rnx_negate = vec_rnx_negate_avx; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx; + module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = 
fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx; + module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx; + } +} + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t n, RNX_MODULE_TYPE mtype) { + memset(module, 0, sizeof(MOD_RNX)); + module->n = n; + module->m = n >> 1; + module->mtype = mtype; + switch (mtype) { + case FFT64: + fft64_init_rnx_module_precomp(module); + fft64_init_rnx_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_rnx_module_info(MOD_RNX* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case FFT64: + fft64_finalize_rnx_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) { + MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX)); + init_rnx_module_info(res, nn, mtype); + return res; +} + +EXPORT void delete_rnx_module_info(MOD_RNX* module_info) { + finalize_rnx_module_info(module_info); + free(module_info); +} + +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; } + +/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */ +EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols)); +} +EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); } + +//////////////// wrappers ////////////////// + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + module->vtable.vec_rnx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . 
X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl); +} +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void rnx_vmp_prepare_contiguous( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space); +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) { + return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module); +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void rnx_vmp_apply_tmp_a( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space); +} + +EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res size + uint64_t a_size, // a size + uint64_t nrows, uint64_t ncols // prep matrix dims +) { + return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void rnx_vmp_apply_dft_to_dft( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols, + tmp_space); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return 
module->vtable.bytes_of_rnx_svp_ppol(module); } + +EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +) { + module->vtable.rnx_svp_prepare(module, ppol, pol); +} + +EXPORT void rnx_svp_apply( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.rnx_svp_apply(module, // N + res, res_size, res_sl, // output + ppol, // prepared pol + a, a_size, a_sl); +} + +EXPORT void rnx_approxdecomp_from_tnxdbl( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a) { // a + module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a); +} + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl); +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c new file mode 100644 index 0000000..2acda14 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c @@ -0,0 +1,59 @@ +#include + +#include "immintrin.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a); + const uint64_t ell = gadget->ell; + const __m256i k = _mm256_set1_epi64x(gadget->k); + const __m256d add_cst = _mm256_set1_pd(gadget->add_cst); + const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask); + const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask); + const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst); + const uint64_t msize = res_size <= ell ? 
res_size : ell; + // gadget decompose column by column + if (msize == ell) { + // this is the main scenario when msize == ell + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_castpd_si256(t_dbl); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } else if (msize > 0) { + // otherwise, if msize < ell: there is one additional rshift + const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k); + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c new file mode 100644 index 0000000..eab2d12 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c @@ -0,0 +1,75 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +typedef union di { + double dv; + uint64_t uv; +} di_t; + +/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */ +EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +) { + if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision"); + TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + // double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) + union di add_cst; + add_cst.dv = UINT64_C(3) << (51 - ell * k); + for (uint64_t i = 0; i < ell; ++i) { + add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1); + } + res->add_cst = add_cst.dv; + // uint64_t and_mask; // uint64(2^(K)-1) + res->and_mask = (UINT64_C(1) << k) - 1; + // uint64_t or_mask; // double(2^52) + union di or_mask; + or_mask.dv = (UINT64_C(1) << 52); + res->or_mask = or_mask.uv; + // double sub_cst; // double(2^52 + 2^(K-1)) + res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1))); + return res; +} + +EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); } + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_ref( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + const uint64_t k = gadget->k; + const uint64_t ell = gadget->ell; + const double add_cst = gadget->add_cst; + const uint64_t and_mask = gadget->and_mask; + const uint64_t or_mask = gadget->or_mask; + const double sub_cst = 
gadget->sub_cst; + const uint64_t msize = res_size <= ell ? res_size : ell; + const uint64_t first_rsh = (ell - msize) * k; + // gadget decompose column by column + if (msize > 0) { + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; ++j) { + double* rr = last_r + j; + di_t t = {.dv = a[j] + add_cst}; + if (msize < ell) t.uv >>= first_rsh; + do { + di_t u; + u.uv = (t.uv & and_mask) | or_mask; + *rr = u.dv - sub_cst; + t.uv >>= k; + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c new file mode 100644 index 0000000..eb56899 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c @@ -0,0 +1,223 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} + +void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +void rnx_negate_ref(uint64_t nn, double* res, const double* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + const uint64_t nn = module->n; + for (uint64_t i = 0; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + memcpy(res_ptr, a_ptr, nn * sizeof(double)); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + rnx_negate_ref(nn, res_ptr, a_ptr); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_rotate_inplace_f64(nn, p, res_ptr); + } else { + rnx_rotate_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_automorphism_inplace_f64(nn, p, res_ptr); + } else { + rnx_automorphism_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_mul_xp_minus_one_inplace(nn, p, res_ptr); + } else { + rnx_mul_xp_minus_one(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h new file mode 100644 index 0000000..16a5e6d --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h @@ -0,0 +1,340 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_H + +#include + +#include "../commons.h" + +/** + * We support the following module families: + * - FFT64: + * the overall precision should fit at all times over 52 bits. + */ +typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE; + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +typedef struct rnx_module_info_t MOD_RNX; + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) 
+ */ +EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode); +EXPORT void delete_rnx_module_info(MOD_RNX* module_info); +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module); + +// basic arithmetic + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/////////////////////////////////////////////////////////////////// +// conversions // +/////////////////////////////////////////////////////////////////// + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32x2( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32x2( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + 
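+
For orientation, here is a minimal usage sketch of the public API declared above (an editorial illustration, not part of the patch): it assumes an FFT64 module, a placeholder ring dimension N = 64, and an include path matching the local build, and it calls only functions whose prototypes appear in this header.

    #include <stdio.h>
    #include "vec_rnx_arithmetic.h"   /* adjust to the actual include path */

    int main(void) {
      enum { N = 64, SIZE = 2 };      /* ring dimension and slice count (placeholders) */
      MOD_RNX* mod = new_rnx_module_info(N, FFT64);

      /* each vector holds SIZE polynomials of N coefficients, contiguous (stride = N) */
      double a[SIZE * N] = {0}, b[SIZE * N] = {0}, r[SIZE * N];
      a[0] = 0.25;   /* a_0(X) = 1/4 */
      b[1] = 0.125;  /* b_0(X) = X/8 */

      vec_rnx_add(mod, r, SIZE, N, a, SIZE, N, b, SIZE, N);  /* r = a + b                 */
      vec_rnx_rotate(mod, 3, r, SIZE, N, r, SIZE, N);        /* r = r . X^3 (in place)    */
      vec_rnx_to_tnxdbl(mod, r, SIZE, N, r, SIZE, N);        /* reduce coeffs centermod 1 */

      printf("r_0[3]=%f r_0[4]=%f\n", r[3], r[4]);           /* expect 0.250000 0.125000  */
      delete_rnx_module_info(mod);
      return 0;
    }

The size/stride pairs (for example res_size and res_sl) let callers pass truncated or strided vectors without copying; as the reference implementations earlier in this patch show, missing input slices are treated as zero and extra output slices are zero-filled.
+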
+/////////////////////////////////////////////////////////////////// +// isolated products (n.log(n), but not particularly optimized // +/////////////////////////////////////////////////////////////////// + +/** @brief res = a * b : small polynomial product */ +EXPORT void rnx_small_single_product( // + const MOD_RNX* module, // N + double* res, // output + const double* a, // a + const double* b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b centermod 1: small polynomial product */ +EXPORT void tnxdbl_small_single_product( // + const MOD_RNX* module, // N + double* torus_res, // output + const double* int_a, // a + const double* torus_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b: small polynomial product */ +EXPORT void znx32_small_single_product( // + const MOD_RNX* module, // N + int32_t* int_res, // output + const int32_t* int_a, // a + const int32_t* int_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b centermod 1: small polynomial product */ +EXPORT void tnx32_small_single_product( // + const MOD_RNX* module, // N + int32_t* torus_res, // output + const int32_t* int_a, // a + const int32_t* torus_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module); + +/////////////////////////////////////////////////////////////////// +// prepared gadget decompositions (optimized) // +/////////////////////////////////////////////////////////////////// + +// decompose from tnx32 + +typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */ +EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +); +EXPORT void delete_tnx32_approxdecomp_gadget(const MOD_RNX* module); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnx32( // + const MOD_RNX* module, // N + const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a // a +); + +// decompose from tnx32x2 + +typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */ +EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella, + uint64_t kb, uint64_t ellb); +EXPORT void delete_tnx32x2_approxdecomp_gadget(const MOD_RNX* module); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnx32x2( // + const MOD_RNX* module, // N + const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a // a +); + +// decompose from tnxdbl + +typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */ +EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +); +EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void 
rnx_approxdecomp_from_tnxdbl( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a + +/////////////////////////////////////////////////////////////////// +// prepared scalar-vector product (optimized) // +/////////////////////////////////////////////////////////////////// + +/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */ +typedef struct rnx_svp_ppol_t RNX_SVP_PPOL; + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */ +EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief frees memory for a prepared vector */ +EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res); + +/** @brief prepares a svp polynomial */ +EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void rnx_svp_apply( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/////////////////////////////////////////////////////////////////// +// prepared vector-matrix product (optimized) // +/////////////////////////////////////////////////////////////////// + +typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT; + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */ +EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); // dimensions +EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void rnx_vmp_prepare_contiguous( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module); + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void rnx_vmp_apply_tmp_a( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res size + uint64_t a_size, // a size + uint64_t nrows, uint64_t ncols // prep matrix dims +); + +/** @brief minimal size of the tmp_space */ +EXPORT void rnx_vmp_apply_dft_to_dft( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t 
rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief sets res = DFT(a) */ +EXPORT void vec_rnx_dft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = iDFT(a_dft) -- idft is not normalized */ +EXPORT void vec_rnx_idft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl // a +); + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c new file mode 100644 index 0000000..04b3ec0 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c @@ -0,0 +1,189 @@ +#include +#include + +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a + *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_add_pd(x0, x2); + x5 = _mm256_add_pd(x1, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a - *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(x0, x2); + x5 = _mm256_sub_pd(x1, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_negate_avx(uint64_t nn, double* res, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = -*b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x2, x3, x4, x5; + const __m256d ZERO = _mm256_set1_pd(0); + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(ZERO, x2); + x5 = _mm256_sub_pd(ZERO, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + bb += 8; + rr += 8; + } while (rr < rrend); +} + +/** @brief sets res = a + b */ +EXPORT 
void vec_rnx_add_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h new file mode 100644 index 0000000..f2e07eb --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h @@ -0,0 +1,88 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H + +#include "vec_rnx_arithmetic.h" + +typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F; +typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F; +typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F; +typedef typeof(vec_rnx_add) VEC_RNX_ADD_F; +typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F; +typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F; +typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F; +typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F; +typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F; +typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F; +typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F; +typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F; +typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F; +typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F; +typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F; +// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F; +typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F; +typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F; +typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F; +typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F; +typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F; +typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F; +typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F; +typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F; +typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(rnx_vmp_prepare_contiguous_tmp_bytes) RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F; +typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F; +typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F; +typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F; +typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F; + +typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE; +struct rnx_module_vtable_t { + VEC_RNX_ZERO_F* vec_rnx_zero; + VEC_RNX_COPY_F* vec_rnx_copy; + VEC_RNX_NEGATE_F* vec_rnx_negate; + VEC_RNX_ADD_F* vec_rnx_add; + VEC_RNX_SUB_F* 
vec_rnx_sub; + VEC_RNX_ROTATE_F* vec_rnx_rotate; + VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one; + VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism; + VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32; + VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32; + VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32; + VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32; + VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2; + VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2; + VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl; + RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product; + RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes; + TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product; + TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes; + ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product; + ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes; + TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product; + TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes; + RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32; + RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2; + RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl; + BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol; + RNX_SVP_PREPARE_F* rnx_svp_prepare; + RNX_SVP_APPLY_F* rnx_svp_apply; + BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat; + RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous; + RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* rnx_vmp_prepare_contiguous_tmp_bytes; + RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a; + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes; + RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft; + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes; + VEC_RNX_DFT_F* vec_rnx_dft; + VEC_RNX_IDFT_F* vec_rnx_idft; +}; + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h new file mode 100644 index 0000000..59a4cf8 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h @@ -0,0 +1,284 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "../reim/reim_fft.h" +#include "vec_rnx_arithmetic.h" +#include "vec_rnx_arithmetic_plugin.h" + +typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP; +struct fft64_rnx_module_precomp_t { + REIM_FFT_PRECOMP* p_fft; + REIM_IFFT_PRECOMP* p_ifft; + REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul; + REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul; +}; + +typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP; +union rnx_module_precomp_t { + FFT64_RNX_MODULE_PRECOMP fft64; +}; + +void fft64_init_rnx_module_precomp(MOD_RNX* module); + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct rnx_module_info_t { + uint64_t n; + uint64_t m; + RNX_MODULE_TYPE mtype; + RNX_MODULE_VTABLE vtable; + RNX_MODULE_PRECOMP precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t, RNX_MODULE_TYPE mtype); + +void finalize_rnx_module_info(MOD_RNX* module); + +void fft64_init_rnx_module_vtable(MOD_RNX* module); + +/////////////////////////////////////////////////////////////////// +// prepared gadget decompositions (optimized) // +/////////////////////////////////////////////////////////////////// + +struct 
tnx32_approxdec_gadget_t { + uint64_t k; + uint64_t ell; + int32_t add_cst; // 1/2.(sum 2^-(i+1)K) + int32_t rshift_base; // 32 - K + int64_t and_mask; // 2^K-1 + int64_t or_mask; // double(2^52) + double sub_cst; // double(2^52 + 2^(K-1)) + uint8_t rshifts[8]; // 32 - (i+1).K +}; + +struct tnx32x2_approxdec_gadget_t { + // TODO +}; + +struct tnxdbl_approxdecomp_gadget_t { + uint64_t k; + uint64_t ell; + double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) + uint64_t and_mask; // uint64(2^(K)-1) + uint64_t or_mask; // double(2^52) + double sub_cst; // double(2^52 + 2^(K-1)) +}; + +EXPORT void vec_rnx_add_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_rnx_add_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . 
X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); + +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module); +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module); + +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/// gadget decompositions + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void 
rnx_approxdecomp_from_tnxdbl_ref( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a +EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a + +EXPORT void vec_rnx_mul_xp_minus_one_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_znx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_znx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnxdbl_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void fft64_rnx_svp_apply_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c new file mode 100644 index 0000000..2a1b296 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c @@ -0,0 +1,91 @@ +#include + +#include "vec_rnx_arithmetic_private.h" +#include "zn_arithmetic_private.h" + +EXPORT void vec_rnx_to_znx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} + +EXPORT void vec_rnx_from_znx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? 
res_size : a_size;
+ for (uint64_t i = 0; i < msize; ++i) {
+ i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
+ }
+ for (uint64_t i = msize; i < res_size; ++i) {
+ memset(res + i * res_sl, 0, nn * sizeof(double));
+ }
+}
+EXPORT void vec_rnx_to_tnx32_ref( //
+ const MOD_RNX* module, // N
+ int32_t* res, uint64_t res_size, uint64_t res_sl, // res
+ const double* a, uint64_t a_size, uint64_t a_sl // a
+) {
+ const uint64_t nn = module->n;
+ const uint64_t msize = res_size < a_size ? res_size : a_size;
+ for (uint64_t i = 0; i < msize; ++i) {
+ dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
+ }
+ for (uint64_t i = msize; i < res_size; ++i) {
+ memset(res + i * res_sl, 0, nn * sizeof(int32_t));
+ }
+}
+EXPORT void vec_rnx_from_tnx32_ref( //
+ const MOD_RNX* module, // N
+ double* res, uint64_t res_size, uint64_t res_sl, // res
+ const int32_t* a, uint64_t a_size, uint64_t a_sl // a
+) {
+ const uint64_t nn = module->n;
+ const uint64_t msize = res_size < a_size ? res_size : a_size;
+ for (uint64_t i = 0; i < msize; ++i) {
+ tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
+ }
+ for (uint64_t i = msize; i < res_size; ++i) {
+ memset(res + i * res_sl, 0, nn * sizeof(double));
+ }
+}
+
+static void dbl_to_tndbl_ref( //
+ const void* UNUSED, // N
+ double* res, uint64_t res_size, // res
+ const double* a, uint64_t a_size // a
+) {
+ static const double OFF_CST = INT64_C(3) << 51;
+ const uint64_t msize = res_size < a_size ? res_size : a_size;
+ for (uint64_t i = 0; i < msize; ++i) {
+ double ai = a[i] + OFF_CST;
+ res[i] = a[i] - (ai - OFF_CST);
+ }
+ memset(res + msize, 0, (res_size - msize) * sizeof(double));
+}
+
+EXPORT void vec_rnx_to_tnxdbl_ref( //
+ const MOD_RNX* module, // N
+ double* res, uint64_t res_size, uint64_t res_sl, // res
+ const double* a, uint64_t a_size, uint64_t a_sl // a
+) {
+ const uint64_t nn = module->n;
+ const uint64_t msize = res_size < a_size ? res_size : a_size;
+ for (uint64_t i = 0; i < msize; ++i) {
+ dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
+ }
+ for (uint64_t i = msize; i < res_size; ++i) {
+ memset(res + i * res_sl, 0, nn * sizeof(double));
+ }
+}
diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c
new file mode 100644
index 0000000..f811148
--- /dev/null
+++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c
@@ -0,0 +1,47 @@
+#include
+
+#include "../coeffs/coeffs_arithmetic.h"
+#include "vec_rnx_arithmetic_private.h"
+
+EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }
+
+EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }
+
+EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }
+
+/** @brief prepares a svp polynomial */
+EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
+ RNX_SVP_PPOL* ppol, // output
+ const double* pol // a
+) {
+ double* const dppol = (double*)ppol;
+ rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
+ reim_fft(module->precomp.fft64.p_fft, dppol);
+}
+
+EXPORT void fft64_rnx_svp_apply_ref( //
+ const MOD_RNX* module, // N
+ double* res, uint64_t res_size, uint64_t res_sl, // output
+ const RNX_SVP_PPOL* ppol, // prepared pol
+ const double* a, uint64_t a_size, uint64_t a_sl // a
+) {
+ const uint64_t nn = module->n;
+ double* const dppol = (double*)ppol;
+
+ const uint64_t auto_end_idx = res_size < a_size ?
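/* Illustrative scalar-vector product sketch (not part of the patch): prepare
 * one polynomial, then multiply a whole vector by it with the functions
 * defined above.  `mod` is an already-initialized MOD_RNX* of ring dimension
 * n; res and a are contiguous (stride n).  Although the header brief mentions
 * DFT space, this reference implementation transforms back with reim_ifft (see
 * the loop just below), so res comes out in coefficient space; the 1/m scaling
 * done in *_svp_prepare_ref presumably absorbs the factor introduced by the
 * unnormalized FFT/iFFT pair. */
void rnx_svp_example(const MOD_RNX* mod, uint64_t n,
                     double* res, uint64_t res_size,
                     const double* pol, const double* a, uint64_t a_size) {
  RNX_SVP_PPOL* ppol = new_rnx_svp_ppol(mod);      // n doubles, cf. fft64_bytes_of_rnx_svp_ppol
  fft64_rnx_svp_prepare_ref(mod, ppol, pol);       // ppol <- FFT(pol / m)
  fft64_rnx_svp_apply_ref(mod, res, res_size, n,   // res[i] <- pol * a[i] for i < min(res_size, a_size)
                          ppol, a, a_size, n);     // remaining limbs of res are zeroed
  delete_rnx_svp_ppol(ppol);
}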
res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const double* a_ptr = a + i * a_sl; + double* const res_ptr = res + i * res_sl; + // copy the polynomial to res, apply fft in place, call fftvec + // _mul, apply ifft in place. + memcpy(res_ptr, a_ptr, nn * sizeof(double)); + reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr); + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol); + reim_ifft(module->precomp.fft64.p_ifft, res_ptr); + } + + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c new file mode 100644 index 0000000..4c1b23d --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c @@ -0,0 +1,196 @@ +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? 
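/* Worked layout example for the fft64 RNX_VMP_PMAT filled above (illustrative,
 * not part of the patch), for nrows = 2, ncols = 3 and nn >= 8.  Each (row, col)
 * entry is FFT'd (after the 1/m scaling) and split into m/4 blocks of 8 doubles
 * (one reim4 block: 4 real + 4 imaginary lanes).  Block blk_i of an entry lives at
 *
 *     (double*)pmat + blk_i * (nrows * ncols * 8) + entry_offset
 *
 * with entry_offset/8 following the column-pair ordering of the loop above:
 *
 *     (r0,c0)=0  (r0,c1)=1  (r1,c0)=2  (r1,c1)=3    <- column pair {0,1}
 *     (r0,c2)=4  (r1,c2)=5                          <- last, odd column
 *
 * so the apply kernels can read two columns of one block as a single contiguous
 * run of 2*nrows reim4 blocks.  For nn < 8 the layout degenerates to plain
 * column-major (col, row) storage, one FFT'd polynomial of nn doubles per entry. */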
ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + const uint64_t ms4 = m >> 2; // m/4 + const uint64_t gsw_iter_doubles = 8 * nrows * ncols; + const uint64_t pref_doubles = 1200; + const double* gsw_pref_ptr = mat_input; + const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles; + const double* gsw_pref_ptr_target = mat_input + pref_doubles; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + const double* mat_blk_start; + uint64_t blk_i; + for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) { + // prefetch the next iteration + if (gsw_pref_ptr_target < gsw_ptr_end) { + gsw_pref_ptr_target += gsw_iter_doubles; + if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + } + reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be 
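/* Scratch-space layout used by fft64_rnx_vmp_apply_dft_to_dft_{ref,avx} above
 * (illustrative summary, not part of the patch):
 *
 *   tmp_space + 0   : mat2cols_output, 2 reim4 blocks           = 128 bytes
 *   tmp_space + 128 : extracted_blk, 1 reim4 block per input row = 64 * min(nrows, a_size) bytes
 *
 * which matches the 128 + 64*row_max returned by
 * fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_*; callers should size the buffer
 * through that helper rather than by hand.  (The nn < 8 fallback reuses the
 * same buffer to copy the input vector when res aliases a_dft.) */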
overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? nrows : a_size; + const uint64_t cols = ncols < res_size ? ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_avx(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c b/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c new file mode 100644 index 0000000..de14ba8 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c @@ -0,0 +1,251 @@ +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return nrows * ncols * module->n * sizeof(double); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module) { + const uint64_t nn = module->n; + return nn * sizeof(int64_t); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + 
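/* Size sanity check for the helpers above (illustrative, not part of the
 * patch): for n = 2048, nrows = 4, ncols = 2,
 *
 *   fft64_bytes_of_rnx_vmp_pmat                    = 4 * 2 * 2048 * 8 = 131072 bytes
 *   fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref = 2048 * 8         =  16384 bytes
 *
 * i.e. the prepared matrix occupies exactly as many bytes as the row-major
 * double matrix it was built from, and preparation needs one polynomial-sized
 * scratch buffer. */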
const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? 
nrows : a_size; + const uint64_t cols = ncols < res_size ? ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_ref(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? nrows : a_size; + + return (128) + (64 * row_max); +} + +#ifdef __APPLE__ +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols); +} +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// avx aliases that need to be defined in the same .c file + +/** @brief number of scratch bytes necessary to prepare a matrix */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module) + __attribute((alias("fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref"))); +#endif + +/** @brief minimal size of the tmp_space */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif + +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// wrappers diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx.c b/spqlios/lib/spqlios/arithmetic/vec_znx.c new file mode 100644 index 0000000..a850bfc --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx.c @@ -0,0 +1,333 @@ +#include +#include +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../q120/q120_arithmetic.h" +#include "../q120/q120_ntt.h" +#include "../reim/reim_fft_internal.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// general function (virtual dispatch) + +EXPORT void vec_znx_add(const MODULE* module, // N + int64_t* 
res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_add(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_sub(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_sub(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_rotate(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_rotate(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_automorphism(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_automorphism(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_normalize_base2k(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + module->func.vec_znx_normalize_base2k(module, // N + log2_base2k, // log2_base2k + res, res_size, res_sl, // res + a, a_size, a_sl, // a + tmp_space); +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N +) { + return module->func.vec_znx_normalize_base2k_tmp_bytes(module // N + ); +} + +// specialized function (ref) + +EXPORT void vec_znx_add_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? 
res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_rotate_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_rotate_inplace_i64(nn, p, res_ptr); + } else { + znx_rotate_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t auto_end_idx = res_size < a_size ? 
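/* Worked example of the size handling in vec_znx_add_ref / vec_znx_sub_ref
 * above (illustrative, not part of the patch), for a_size = 2, b_size = 3,
 * res_size = 4:
 *
 *   res[0] = a[0] + b[0]    added:  i < min(res_size, a_size) = 2
 *   res[1] = a[1] + b[1]
 *   res[2] = b[2]           copied: i < min(res_size, b_size) = 3
 *                           (negated instead of copied for vec_znx_sub_ref)
 *   res[3] = 0              zero-padded up to res_size
 *
 * The a_size > b_size branch is symmetric, except that for subtraction the
 * extra limbs of a are copied as-is, since only b is negated. */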
res_size : a_size; + + for (uint64_t i = 0; i < auto_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_automorphism_inplace_i64(nn, p, res_ptr); + } else { + znx_automorphism_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + const uint64_t nn = module->nn; + + // use MSB limb of res for carry propagation + int64_t* cout = (int64_t*)tmp_space; + int64_t* cin = 0x0; + + // propagate carry until first limb of res + int64_t i = a_size - 1; + for (; i >= res_size; --i) { + znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin); + cin = cout; + } + + // propagate carry and normalize + for (; i >= 1; --i) { + znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin); + cin = cout; + } + + // normalize last limb + znx_normalize(nn, log2_base2k, res, 0x0, a, cin); + + // extend result with zeros + for (uint64_t i = a_size; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N +) { + const uint64_t nn = module->nn; + return nn * sizeof(int64_t); +} + + +// alias have to be defined in this unit: do not move +#ifdef __APPLE__ +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module // N + ) { + return vec_znx_normalize_base2k_tmp_bytes_ref(module); +} +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // + const MODULE* module // N +) { + return vec_znx_normalize_base2k_tmp_bytes_ref(module); +} +#else +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // + const MODULE* module // N +) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module // N +) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); +#endif + +/** @brief sets res = 0 */ +EXPORT void vec_znx_zero(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + module->func.vec_znx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_znx_copy(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_znx_negate(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_znx_zero_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + uint64_t nn = module->nn; + for (uint64_t i = 0; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_copy_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = 
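/* Note on vec_znx_normalize_base2k_ref above (illustrative, not part of the
 * patch): the carry chain starts at limb a_size-1 and moves toward limb 0, and
 * limbs of `a` beyond res_size still contribute their carries even though
 * their digits are dropped, which suggests limb 0 holds the most-significant
 * digits of the base-2^K representation.  A minimal call, with the scratch
 * buffer sized by the matching helper (n*8 bytes for this backend); assumes
 * <stdlib.h>: */
void normalize_example(const MODULE* module, uint64_t n,
                       int64_t* res, uint64_t res_size,
                       const int64_t* a, uint64_t a_size) {
  uint8_t* tmp = (uint8_t*)malloc(vec_znx_normalize_base2k_tmp_bytes(module));
  vec_znx_normalize_base2k(module, 16, res, res_size, n, a, a_size, n, tmp);  // K = 16, contiguous limbs
  free(tmp);
}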
res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_negate_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h b/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h new file mode 100644 index 0000000..b93a571 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h @@ -0,0 +1,357 @@ +#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H +#define SPQLIOS_VEC_ZNX_ARITHMETIC_H + +#include + +#include "../commons.h" +#include "../reim/reim_fft.h" + +/** + * We support the following module families: + * - FFT64: + * all the polynomials should fit at all times over 52 bits. + * for FHE implementations, the recommended limb-sizes are + * between K=10 and 20, which is good for low multiplicative depths. + * - NTT120: + * all the polynomials should fit at all times over 119 bits. + * for FHE implementations, the recommended limb-sizes are + * between K=20 and 40, which is good for large multiplicative depths. + */ +typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE; + +/** @brief opaque structure that describr the modules (ZnX,TnX) and the hardware */ +typedef struct module_info_t MODULE; +/** @brief opaque type that represents a prepared matrix */ +typedef struct vmp_pmat_t VMP_PMAT; +/** @brief opaque type that represents a vector of znx in DFT space */ +typedef struct vec_znx_dft_t VEC_ZNX_DFT; +/** @brief opaque type that represents a vector of znx in large coeffs space */ +typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG; +/** @brief opaque type that represents a prepared scalar vector product */ +typedef struct svp_ppol_t SVP_PPOL; +/** @brief opaque type that represents a prepared left convolution vector product */ +typedef struct cnv_pvec_l_t CNV_PVEC_L; +/** @brief opaque type that represents a prepared right convolution vector product */ +typedef struct cnv_pvec_r_t CNV_PVEC_R; + +/** @brief bytes needed for a vec_znx in DFT space */ +EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +/** @brief allocates a vec_znx in DFT space */ +EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +/** @brief frees memory from a vec_znx in DFT space */ +EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res); + +/** @brief bytes needed for a vec_znx_big */ +EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size); + +/** @brief allocates a vec_znx_big */ +EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N + uint64_t size); +/** @brief frees memory from a vec_znx_big */ +EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res); + +/** @brief bytes needed for a prepared vector */ +EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N + +/** @brief allocates a prepared vector */ +EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N + +/** @brief frees memory for a prepared vector */ +EXPORT void delete_svp_ppol(SVP_PPOL* res); + +/** @brief bytes needed for a prepared 
matrix */ +EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief allocates a prepared matrix */ +EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief frees memory for a prepared matrix */ +EXPORT void delete_vmp_pmat(VMP_PMAT* res); + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) + */ +EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode); +EXPORT void delete_module_info(MODULE* module_info); +EXPORT uint64_t module_get_n(const MODULE* module); + +/** @brief sets res = 0 */ +EXPORT void vec_znx_zero(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_znx_copy(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a */ +EXPORT void vec_znx_negate(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a + b */ +EXPORT void vec_znx_add(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_znx_sub(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = k-normalize-reduce(a) */ +EXPORT void vec_znx_normalize_base2k(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space (size >= N) +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */ +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N +); + +/** @brief sets res = a . 
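/* Minimal end-to-end sketch of the public API declared above (illustrative,
 * not part of the patch): create an FFT64 module, add two base-2^K vectors
 * limb-wise, then renormalize the sum.  N must be a ring dimension supported
 * by the backend, and K should stay in the 10..20 range recommended above for
 * FFT64.  In real code the MODULE would be created once and reused, since it
 * carries the FFT precomputations. */
#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"

void add_and_normalize(uint64_t N, uint64_t K,
                       int64_t* res, uint64_t res_size,
                       const int64_t* a, uint64_t a_size,
                       const int64_t* b, uint64_t b_size) {
  MODULE* module = new_module_info(N, FFT64);
  int64_t* sum = (int64_t*)malloc(res_size * N * sizeof(int64_t));
  uint8_t* tmp = (uint8_t*)malloc(vec_znx_normalize_base2k_tmp_bytes(module));
  // limb-wise add without carry handling...
  vec_znx_add(module, sum, res_size, N, a, a_size, N, b, b_size, N);
  // ...then propagate carries so every limb is again a valid base-2^K digit
  vec_znx_normalize_base2k(module, K, res, res_size, N, sum, res_size, N, tmp);
  free(tmp);
  free(sum);
  delete_module_info(module);
}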
X^p */ +EXPORT void vec_znx_rotate(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_automorphism(const MODULE* module, // N + const int64_t p, // X-X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void vmp_prepare_contiguous(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares a vmp matrix (mat[row*ncols+col] points to the item) */ +EXPORT void vmp_prepare_dblptr(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief sets res = 0 */ +EXPORT void vec_dft_zero(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size // res +); + +/** @brief sets res = a+b */ +EXPORT void vec_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +/** @brief sets res = a-b */ +EXPORT void vec_dft_sub(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +/** @brief sets res = DFT(a) */ +EXPORT void vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */ +EXPORT void vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +/** @brief tmp bytes required for vec_znx_idft */ +EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module); + +/** + * @brief sets res = iDFT(a_dft) -- output in big coeffs space + * + * @note a_dft is overwritten + */ +EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a-b */ +EXPORT void vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b 
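/* Worked example for vec_znx_rotate / vec_znx_automorphism above (illustrative,
 * not part of the patch), assuming the usual negacyclic convention X^N = -1
 * for this ring.  For N = 4 and a(X) = a0 + a1*X + a2*X^2 + a3*X^3:
 *
 *   rotate, p = 1:        a(X) * X  =  -a3 + a0*X + a1*X^2 + a2*X^3
 *   automorphism, p = 3:  a(X^3)    =   a0 + a3*X - a2*X^2 + a1*X^3
 *
 * Each output limb i is the image of input limb i; limbs of res beyond
 * min(res_size, a_size) are zeroed, as in the reference code earlier. */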
+); +EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */ +EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void fft64_svp_apply_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_range_normalize_base2k( // + const MODULE* module, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module // N +); + +/** @brief sets res = a . 
X^p */ +EXPORT void vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void svp_apply_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief prepares a svp polynomial */ +EXPORT void svp_prepare(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +); + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp); + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void vmp_prepare_contiguous(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void vmp_apply_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief minimal size of the tmp_space */ +EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +; + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h b/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h new file mode 100644 index 0000000..528dfad --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h @@ -0,0 +1,481 @@ +#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H +#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "../q120/q120_ntt.h" +#include "vec_znx_arithmetic.h" + +/** + * Layouts families: + * + * fft64: + * K: <= 20, N: <= 65536, ell: <= 200 + * vec normalized: represented by int64 + * vec large: represented by int64 (expect <=52 bits) + * vec DFT: represented by double (reim_fft 
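/* Pipeline sketch for the DFT-domain vector-matrix product declared above
 * (illustrative, not part of the patch): prepare the matrix once, then take an
 * input vector from coefficients to DFT, multiply, come back to big
 * coefficients, and renormalize to base-2^K limbs.  `module` is an FFT64
 * MODULE*; all int64 vectors are contiguous (stride N). */
#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"

void vmp_pipeline(const MODULE* module, uint64_t N, uint64_t K,
                  const int64_t* mat, uint64_t nrows, uint64_t ncols,
                  const int64_t* a, uint64_t a_size,
                  int64_t* res, uint64_t res_size) {
  VMP_PMAT* pmat = new_vmp_pmat(module, nrows, ncols);
  VEC_ZNX_DFT* a_dft = new_vec_znx_dft(module, a_size);
  VEC_ZNX_DFT* prod_dft = new_vec_znx_dft(module, ncols);
  VEC_ZNX_BIG* prod_big = new_vec_znx_big(module, ncols);

  // one scratch buffer, sized for the largest of the three steps that need it
  uint64_t tmp_bytes = vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols);
  uint64_t t = vmp_apply_dft_to_dft_tmp_bytes(module, ncols, a_size, nrows, ncols);
  if (t > tmp_bytes) tmp_bytes = t;
  t = vec_znx_big_normalize_base2k_tmp_bytes(module);
  if (t > tmp_bytes) tmp_bytes = t;
  uint8_t* tmp = (uint8_t*)malloc(tmp_bytes);

  vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp);  // once per matrix
  vec_znx_dft(module, a_dft, a_size, a, a_size, N);              // coefficients -> DFT
  vmp_apply_dft_to_dft(module, prod_dft, ncols, a_dft, a_size,   // row vector x matrix
                       pmat, nrows, ncols, tmp);
  vec_znx_idft_tmp_a(module, prod_big, ncols, prod_dft, ncols);  // DFT -> big coefficients (prod_dft destroyed)
  vec_znx_big_normalize_base2k(module, K, res, res_size, N,      // carry-propagate to base-2^K limbs
                               prod_big, ncols, tmp);

  free(tmp);
  delete_vec_znx_big(prod_big);
  delete_vec_znx_dft(prod_dft);
  delete_vec_znx_dft(a_dft);
  delete_vmp_pmat(pmat);
}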
space) + * On AVX2 inftastructure, PMAT, LCNV, RCNV use a special reim4_fft space + * + * ntt120: + * K: <= 50, N: <= 65536, ell: <= 80 + * vec normalized: represented by int64 + * vec large: represented by int128 (expect <=120 bits) + * vec DFT: represented by int64x4 (ntt120 space) + * On AVX2 inftastructure, PMAT, LCNV, RCNV use a special ntt120 space + * + * ntt104: + * K: <= 40, N: <= 65536, ell: <= 80 + * vec normalized: represented by int64 + * vec large: represented by int128 (expect <=120 bits) + * vec DFT: represented by int64x4 (ntt120 space) + * On AVX512 inftastructure, PMAT, LCNV, RCNV use a special ntt104 space + */ + +struct fft64_module_info_t { + // pre-computation for reim_fft + REIM_FFT_PRECOMP* p_fft; + // pre-computation for mul_fft + REIM_FFTVEC_MUL_PRECOMP* mul_fft; + // pre-computation for reim_from_znx6 + REIM_FROM_ZNX64_PRECOMP* p_conv; + // pre-computation for reim_tp_znx6 + REIM_TO_ZNX64_PRECOMP* p_reim_to_znx; + // pre-computation for reim_fft + REIM_IFFT_PRECOMP* p_ifft; + // pre-computation for reim_fftvec_addmul + REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul; +}; + +struct q120_module_info_t { + // pre-computation for q120b to q120b ntt + q120_ntt_precomp* p_ntt; + // pre-computation for q120b to q120b intt + q120_ntt_precomp* p_intt; +}; + +// TODO add function types here +typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F; +typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F; +typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F; +typedef typeof(vec_znx_add) VEC_ZNX_ADD_F; +typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F; +typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F; +typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F; +typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F; +typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F; +typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F; +typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F; +typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F; +typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F; +typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F; +typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F; +typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F; +typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F; +typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F; +typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F; +typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F; +typedef typeof(svp_prepare) SVP_PREPARE; +typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F; +typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F; +typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(vmp_prepare_contiguous_tmp_bytes) VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F; +typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F; +typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F; +typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F; 
+typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F; +typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F; +typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F; +typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F; +typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F; + +struct module_virtual_functions_t { + // TODO add functions here + VEC_ZNX_ZERO_F* vec_znx_zero; + VEC_ZNX_COPY_F* vec_znx_copy; + VEC_ZNX_NEGATE_F* vec_znx_negate; + VEC_ZNX_ADD_F* vec_znx_add; + VEC_ZNX_DFT_F* vec_znx_dft; + VEC_ZNX_IDFT_F* vec_znx_idft; + VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes; + VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a; + VEC_ZNX_SUB_F* vec_znx_sub; + VEC_ZNX_ROTATE_F* vec_znx_rotate; + VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism; + VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k; + VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k; + VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k; + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_ADD_F* vec_znx_big_add; + VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small; + VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2; + VEC_ZNX_BIG_SUB_F* vec_znx_big_sub; + VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a; + VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b; + VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2; + VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate; + VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism; + SVP_PREPARE* svp_prepare; + SVP_APPLY_DFT_F* svp_apply_dft; + ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product; + ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes; + VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous; + VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* vmp_prepare_contiguous_tmp_bytes; + VMP_APPLY_DFT_F* vmp_apply_dft; + VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes; + VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft; + VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes; + BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft; + BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big; + BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol; + BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat; +}; + +union backend_module_info_t { + struct fft64_module_info_t fft64; + struct q120_module_info_t q120; +}; + +struct module_info_t { + // generic parameters + MODULE_TYPE module_type; + uint64_t nn; + uint64_t m; + // backend_dependent functions + union backend_module_info_t mod; + // virtual functions + struct module_virtual_functions_t func; +}; + +EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size); + +EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module); // N + +EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +EXPORT void vec_znx_zero_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +); + +EXPORT void vec_znx_copy_ref(const MODULE* precomp, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_negate_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, 
uint64_t a_sl // a +); + +EXPORT void vec_znx_negate_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_add_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_znx_add_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_sub_ref(const MODULE* precomp, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_sub_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // inp + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N +); + +EXPORT void vec_znx_rotate_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vmp_prepare_ref(const MODULE* precomp, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols // a +); + +EXPORT void vmp_apply_dft_ref(const MODULE* precomp, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols // prep matrix +); + +EXPORT void vec_dft_zero_ref(const MODULE* precomp, // N + VEC_ZNX_DFT* res, uint64_t res_size // res +); + +EXPORT void vec_dft_add_ref(const MODULE* precomp, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void vec_dft_sub_ref(const MODULE* precomp, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void vec_dft_ref(const MODULE* precomp, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_idft_ref(const MODULE* precomp, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size); + +EXPORT void vec_znx_big_normalize_ref(const MODULE* precomp, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const 
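/* Hypothetical backend-selection sketch (not part of the patch): the virtual
 * table above is meant to be filled once at module construction, after which
 * the public wrappers in vec_znx.c simply forward through module->func.  The
 * has_avx2 flag and the name fill_fft64_vtable are placeholders; the real
 * selection logic lives in the module constructor. */
static void fill_fft64_vtable(struct module_info_t* module, int has_avx2) {
  module->func.vec_znx_add = has_avx2 ? vec_znx_add_avx : vec_znx_add_ref;
  module->func.vec_znx_sub = has_avx2 ? vec_znx_sub_avx : vec_znx_sub_ref;
  module->func.vec_znx_negate = has_avx2 ? vec_znx_negate_avx : vec_znx_negate_ref;
  module->func.vec_znx_copy = vec_znx_copy_ref;  // no avx variant of copy/zero
  module->func.vec_znx_zero = vec_znx_zero_ref;
  module->func.vec_znx_rotate = vec_znx_rotate_ref;
  module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
  module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
  module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
}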
SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */ +EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */ +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N + +); + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, // a + uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module // N +); + +EXPORT void fft64_vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void fft64_vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module); + +EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** */ +EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module); + +EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +// big additions/subtractions + +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a-b */ +EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // 
res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . X^p */ +EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +); + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void fft64_znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp); + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + 
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_avx.c b/spqlios/lib/spqlios/arithmetic/vec_znx_avx.c new file mode 100644 index 0000000..100902d --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_avx.c @@ -0,0 +1,103 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// specialized function (ref) + +// Note: these functions do not have an avx variant. +#define znx_copy_i64_avx znx_copy_i64_ref +#define znx_zero_i64_avx znx_zero_i64_ref + +EXPORT void vec_znx_add_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? 
res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_negate_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_big.c b/spqlios/lib/spqlios/arithmetic/vec_znx_big.c new file mode 100644 index 0000000..923703c --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_big.c @@ -0,0 +1,270 @@ +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return module->func.bytes_of_vec_znx_big(module, size); +} + +// public wrappers + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size); +} + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl); +} + +EXPORT void vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a-b */ +EXPORT void vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size); +} + +EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl); +} + 
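+/* Usage sketch (illustrative only, not part of this patch): the wrappers in this file
+ * dispatch through module->func, so a caller only touches the public entry points.
+ * `module`, the base-2^k parameter `k`, the ring dimension `nn` (used as res_sl) and
+ * the operand buffers are assumed to be set up elsewhere.
+ *
+ *   VEC_ZNX_BIG* acc = new_vec_znx_big(module, size);
+ *   vec_znx_big_add(module, acc, size, a_big, size, b_big, size);
+ *   uint8_t* tmp = malloc(vec_znx_big_normalize_base2k_tmp_bytes(module));
+ *   vec_znx_big_normalize_base2k(module, k, res, size, nn, acc, size, tmp);
+ *   free(tmp);
+ *   delete_vec_znx_big(acc);
+ */
+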
+EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size); +} +EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size); +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size); +} + +// private wrappers + +EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return module->nn * size * sizeof(double); +} + +EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return spqlios_alloc(bytes_of_vec_znx_big(module, size)); +} + +EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); } + +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + (int64_t*)b, b_size, n); +} +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + b, b_size, b_sl); +} +EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + a, a_size, a_sl, // + b, b_size, b_sl); +} + +/** @brief sets res = a-b */ +EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + (int64_t*)b, b_size, n); +} + +EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, // + n, b, b_size, b_sl); +} +EXPORT void 
fft64_vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + a, a_size, a_sl, // + (int64_t*)b, b_size, n); +} +EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, // + n, a, a_size, // + a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . X^p */ +EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + uint64_t nn = module->nn; + vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn); +} + +/** @brief sets res = a(X^p) */ +EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + uint64_t nn = module->nn; + vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn); +} + +EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +) { + module->func.vec_znx_big_normalize_base2k(module, // N + k, // base-2^k + res, res_size, res_sl, // res + a, a_size, // a + tmp_space); +} + +EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N +) { + return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module // N + ); +} + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_range_normalize_base2k( // + const MODULE* module, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +) { + module->func.vec_znx_big_range_normalize_base2k(module, log2_base2k, res, res_size, res_sl, a, a_range_begin, + a_range_xend, a_range_step, tmp_space); +} + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module // N +) { + return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module); +} + +EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space) { + uint64_t a_sl = module->nn; + module->func.vec_znx_normalize_base2k(module, // N + k, // log2_base2k + res, res_size, res_sl, // res + (int64_t*)a, a_size, a_sl, // a + tmp_space); +} + +EXPORT void fft64_vec_znx_big_range_normalize_base2k( // + const MODULE* module, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a + uint8_t* tmp_space) { + // convert the range indexes to int64[] slices + 
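// (in the fft64 layout a VEC_ZNX_BIG limb is nn contiguous 64-bit words, read here as
+ // int64 coefficients: limb i of the range starts at nn*(a_begin + i*a_step), so the slice
+ // below has ceil((a_end - a_begin)/a_step) limbs and stride nn*a_step)
+ 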
const int64_t* a_st = ((int64_t*)a) + module->nn * a_begin; + const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step; + const uint64_t a_sl = module->nn * a_step; + // forward the call + module->func.vec_znx_normalize_base2k(module, // N + k, // log2_base2k + res, res_size, res_sl, // res + a_st, a_size, a_sl, // a + tmp_space); +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_dft.c b/spqlios/lib/spqlios/arithmetic/vec_znx_dft.c new file mode 100644 index 0000000..16b3a9e --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_dft.c @@ -0,0 +1,162 @@ +#include + +#include "../q120/q120_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +EXPORT void vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl); +} + +EXPORT void vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +) { + return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp); +} + +EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module) { return module->func.vec_znx_idft_tmp_bytes(module); } + +EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size); +} + +EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return module->func.bytes_of_vec_znx_dft(module, size); +} + +// fft64 backend +EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return module->nn * size * sizeof(double); +} + +EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return spqlios_alloc(bytes_of_vec_znx_dft(module, size)); +} + +EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); } + +EXPORT void fft64_vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t smin = res_size < a_size ? res_size : a_size; + const uint64_t nn = module->nn; + + for (uint64_t i = 0; i < smin; i++) { + reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl); + reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn); + } + + // fill up remaining part with 0's + double* const dres = (double*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT void fft64_vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // unused +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? 
res_size : a_size; + if ((double*)res != (double*)a_dft) { + memcpy(res, a_dft, smin * nn * sizeof(double)); + } + + for (uint64_t i = 0; i < smin; i++) { + reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn); + } + + // fill up remaining part with 0's + int64_t* const dres = (int64_t*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module) { return 0; } + +EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + int64_t* const tres = (int64_t*)res; + double* const ta = (double*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + reim_ifft(module->mod.fft64.p_ifft, ta + i * nn); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +// ntt120 backend + +EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + int64_t* tres = (int64_t*)res; + for (uint64_t i = 0; i < smin; i++) { + q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl); + q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4)); + } + + // fill up remaining part with 0's + memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t)); +} + +EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + __int128_t* const tres = (__int128_t*)res; + const int64_t* const ta = (int64_t*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t)); + q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp); + q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres)); +} + +EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module) { return module->nn * 4 * sizeof(uint64_t); } + +EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? 
res_size : a_size; + + __int128_t* const tres = (__int128_t*)res; + int64_t* const ta = (int64_t*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4)); + q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4)); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres)); +} diff --git a/spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c b/spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c new file mode 100644 index 0000000..dbca7cc --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c @@ -0,0 +1 @@ +#include "vec_znx_arithmetic_private.h" diff --git a/spqlios/lib/spqlios/arithmetic/vector_matrix_product.c b/spqlios/lib/spqlios/arithmetic/vector_matrix_product.c new file mode 100644 index 0000000..79ab40c --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vector_matrix_product.c @@ -0,0 +1,240 @@ +#include + +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return module->func.bytes_of_vmp_pmat(module, nrows, ncols); +} + +// fft64 +EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return module->nn * nrows * ncols * sizeof(double); +} + +EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols)); +} + +EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); } + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void vmp_prepare_contiguous(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space); +} + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols) { + return module->func.vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + + double* output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it 
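+ // (pmat is stored block-major: for a fixed reim4 block index blk_i, the 8-double blocks of
+ // all nrows*ncols entries are grouped together, so consecutive blocks of one entry sit
+ // offset = nrows*ncols*8 doubles apart)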
+ reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, res); + } + } + } +} + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols) { + const uint64_t nn = module->nn; + return nn * sizeof(int64_t); +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space); +} + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? 
ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output); + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output); + } + } + } else { + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + for (uint64_t row_i = 0; row_i < 1; row_i++) { + reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t nn = module->nn; + const uint64_t row_max = nrows < a_size ? nrows : a_size; + + return (row_max * nn * sizeof(double)) + (128) + (64 * row_max); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? 
nrows : a_size; + + return (128) + (64 * row_max); +} + +EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space); +} + +EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void vmp_apply_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->func.vmp_apply_dft_tmp_bytes(module, res_size, a_size, nrows, ncols); +} diff --git a/spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c b/spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c new file mode 100644 index 0000000..f428650 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c @@ -0,0 +1,137 @@ +#include + +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + + double* output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, 
res); + } + } + } +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space); +} + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output); + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output); + } + } + } else { + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + for (uint64_t row_i = 0; row_i < 1; row_i++) { + reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} diff --git 
a/spqlios/lib/spqlios/arithmetic/zn_api.c b/spqlios/lib/spqlios/arithmetic/zn_api.c new file mode 100644 index 0000000..28d5c8d --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_api.c @@ -0,0 +1,169 @@ +#include + +#include "zn_arithmetic_private.h" + +void default_init_z_module_precomp(MOD_Z* module) { + // Add here initialization of items that are in the precomp +} + +void default_finalize_z_module_precomp(MOD_Z* module) { + // Add here deleters for items that are in the precomp +} + +void default_init_z_module_vtable(MOD_Z* module) { + // Add function pointers here + module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref; + module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref; + module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref; + module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref; + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref; + module->vtable.dbl_to_tn32 = dbl_to_tn32_ref; + module->vtable.tn32_to_dbl = tn32_to_dbl_ref; + module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref; + module->vtable.i32_to_dbl = i32_to_dbl_ref; + module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref; + module->vtable.i64_to_dbl = i64_to_dbl_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx; + } +} + +void init_z_module_info(MOD_Z* module, // + Z_MODULE_TYPE mtype) { + memset(module, 0, sizeof(MOD_Z)); + module->mtype = mtype; + switch (mtype) { + case DEFAULT: + default_init_z_module_precomp(module); + default_init_z_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_z_module_info(MOD_Z* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case DEFAULT: + default_finalize_z_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) { + MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z)); + init_z_module_info(res, mtype); + return res; +} + +EXPORT void delete_z_module_info(MOD_Z* module_info) { + finalize_z_module_info(module_info); + free(module_info); +} + +//////////////// wrappers ////////////////// + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void 
i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +EXPORT void zn32_vmp_prepare_contiguous( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols); +} + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols); +} +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_to_tn32(module, res, res_size, a, a_size); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.tn32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.i32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size); +} + +/** small int (int64 space, <= 2^50) to double */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + module->vtable.i64_to_dbl(module, res, res_size, a, a_size); +} diff --git a/spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c b/spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c new file mode 100644 index 0000000..616b9a3 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c @@ -0,0 +1,81 @@ +#include + +#include "zn_arithmetic_private.h" + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, uint64_t ell) { + if (k * ell > 50) { + 
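// (the decomposition below uses the add_cst constant to land all k*ell digit bits in the
+ // low bits of the double's 52-bit mantissa; capping k*ell at 50 keeps that packing exact)
+ 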
return spqlios_error("approx decomposition requested is too precise for doubles"); + } + if (k < 1) { + return spqlios_error("approx decomposition supports k>=1"); + } + TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET)); + memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + double add_cst = INT64_C(3) << (51 - k * ell); + for (uint64_t i = 0; i < ell; ++i) { + add_cst += pow(2., -(double)(i * k + 1)); + } + res->add_cst = add_cst; + res->and_mask = (UINT64_C(1) << k) - 1; + res->sub_cst = UINT64_C(1) << (k - 1); + for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k; + return res; +} +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); } + +EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, // + TNDBL_APPROXDECOMP_GADGET* res, // + uint64_t k, uint64_t ell) { + return 0; +} + +typedef union { + double dv; + uint64_t uv; +} du_t; + +#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \ + if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \ + const uint64_t ell = gadget->ell; \ + const double add_cst = gadget->add_cst; \ + const uint8_t* const rshifts = gadget->rshifts; \ + const ITYPE and_mask = gadget->and_mask; \ + const ITYPE sub_cst = gadget->sub_cst; \ + ITYPE* rr = res; \ + const double* aa = a; \ + const double* aaend = a + a_size; \ + while (aa < aaend) { \ + du_t t = {.dv = *aa + add_cst}; \ + for (uint64_t i = 0; i < ell; ++i) { \ + ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \ + *rr = (v & and_mask) - sub_cst; \ + ++rr; \ + } \ + ++aa; \ + } + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size // +){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)} + +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t) +} diff --git a/spqlios/lib/spqlios/arithmetic/zn_arithmetic.h b/spqlios/lib/spqlios/arithmetic/zn_arithmetic.h new file mode 100644 index 0000000..3503e20 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_arithmetic.h @@ -0,0 +1,135 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_H +#define SPQLIOS_ZN_ARITHMETIC_H + +#include + +#include "../commons.h" + +typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE; + +/** @brief opaque structure that describes the module and the hardware */ +typedef struct z_module_info_t MOD_Z; + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) 
+ */ +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode); +EXPORT void delete_z_module_info(MOD_Z* module_info); + +typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET; + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, + uint64_t ell); // base 2^k, and size + +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr); + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief opaque type that represents a prepared matrix */ +typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT; + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void zn32_vmp_prepare_contiguous( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols); // a + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space. 
+ * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double + * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space, <= 2^50) to double + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_H diff --git a/spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h b/spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h new file mode 100644 index 0000000..d400a72 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h @@ -0,0 +1,39 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H +#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H + +#include "zn_arithmetic.h" + +typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F; +typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F; +typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F; +typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F; +typedef typeof(dbl_to_tn32) DBL_TO_TN32_F; +typedef typeof(tn32_to_dbl) TN32_TO_DBL_F; +typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F; +typedef typeof(i32_to_dbl) I32_TO_DBL_F; +typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F; +typedef typeof(i64_to_dbl) I64_TO_DBL_F; + +typedef struct z_module_vtable_t Z_MODULE_VTABLE; +struct z_module_vtable_t { + I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl; + I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl; + I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl; + BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat; + ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous; + ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32; + ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16; + ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8; + DBL_TO_TN32_F* dbl_to_tn32; + TN32_TO_DBL_F* tn32_to_dbl; + DBL_ROUND_TO_I32_F* dbl_round_to_i32; + I32_TO_DBL_F* i32_to_dbl; + DBL_ROUND_TO_I64_F* dbl_round_to_i64; + I64_TO_DBL_F* i64_to_dbl; +}; + +#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H diff --git a/spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h b/spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h new file mode 100644 index 0000000..3ff6c48 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h @@ -0,0 +1,150 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H +#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "zn_arithmetic.h" +#include "zn_arithmetic_plugin.h" + +typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP; +struct main_z_module_precomp_t { + // TODO +}; + +typedef union z_module_precomp_t Z_MODULE_PRECOMP; +union z_module_precomp_t { 
+ MAIN_Z_MODULE_PRECOMP main; +}; + +void main_init_z_module_precomp(MOD_Z* module); + +void main_finalize_z_module_precomp(MOD_Z* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct z_module_info_t { + Z_MODULE_TYPE mtype; + Z_MODULE_VTABLE vtable; + Z_MODULE_PRECOMP precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype); + +void main_init_z_module_vtable(MOD_Z* module); + +struct tndbl_approxdecomp_gadget_t { + uint64_t k; + uint64_t ell; + double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K) + int64_t and_mask; // (2^K)-1 + int64_t sub_cst; // 2^(K-1) + uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1] +}; + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +); + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_avx( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t 
res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H diff --git a/spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c b/spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c new file mode 100644 index 0000000..f016a71 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c @@ -0,0 +1,108 @@ +#include + +#include "zn_arithmetic_private.h" + +typedef union { + double dv; + int64_t s64v; + int32_t s32v; + uint64_t u64v; + uint32_t u32v; +} di_t; + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32)); + static const int32_t XOR_CST = (INT32_C(1) << 31); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))}; + static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31))); + static const int32_t XOR_CST = INT32_C(1) << 31; + const uint64_t msize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)}; + static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)(INT64_C(3) << (51)); + static const int64_t AND_CST = (INT64_C(1) << 52) - 1; + static const int64_t SUB_CST = INT64_C(1) << 51; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = (t.s64v & AND_CST) - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int64_t)); +} + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + static const uint64_t ADD_CST = UINT64_C(1) << 51; + static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1; + static const di_t OR_CST = {.dv = (INT64_C(1) << 52)}; + static const double SUB_CST = INT64_C(3) << 51; + const uint64_t msize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c new file mode 100644 index 0000000..563f199 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_avx.c" diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c new file mode 100644 index 0000000..0626c9b --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_ref.c" diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c new file mode 100644 index 0000000..3fbc8fb --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c @@ -0,0 +1,223 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const int32_t* bb = b; + const int32_t* pref_bb = b; + const uint64_t pref_iters = 128; + const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows; + const uint64_t pref_last = pref_iters > nrows ? 
0 : nrows - pref_iters; + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + for (uint64_t i = 0; i < pref_start; ++i) { + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + } + // we do the first iteration + __m256i x = _mm256_set1_epi32(a[0]); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += 32; + uint64_t row = 1; + for (; // + row < pref_last; // + ++row, bb += 32) { + // prefetch the next iteration + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + for (; // + row < nrows; // + ++row, bb += 32) { + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 24 * sizeof(int32_t)); + return; 
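+    // (the first row is peeled off before the accumulation loop below,
+    //  hence this explicit zeroing of the 24-column block when nrows == 0)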
+ } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); +} +void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 16 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); +} + +void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 8 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_avx[4] = { // + zn32_vec_fn(mat8cols_avx), // + zn32_vec_fn(mat16cols_avx), // + zn32_vec_fn(mat24cols_avx), // + zn32_vec_fn(mat32cols_avx)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? 
res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint64_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c new file mode 100644 index 0000000..c3d0bc9 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c @@ -0,0 +1,88 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +// the ref version shares the same implementation for each fixed column size +// optimized implementations may do something different. +static __always_inline void IMPL_zn32_vec_matcols_ref( + const uint64_t NCOLS, // fixed number of columns + uint64_t nrows, // nrows of b + int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant + const INTTYPE* a, // a: nrows-sized vector + const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix +) { + memset(res, 0, NCOLS * sizeof(int32_t)); + for (uint64_t row = 0; row < nrows; ++row) { + int32_t ai = a[row]; + const int32_t* bb = b + row * b_sl; + for (uint64_t i = 0; i < NCOLS; ++i) { + res[i] += ai * bb[i]; + } + } +} + +void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_ref[4] = { // + zn32_vec_fn(mat8cols_ref), // + zn32_vec_fn(mat16cols_ref), // + zn32_vec_fn(mat24cols_ref), // + zn32_vec_fn(mat32cols_ref)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? 
res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c new file mode 100644 index 0000000..74480aa --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_avx.c" diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c new file mode 100644 index 0000000..d1de571 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_ref.c" diff --git a/spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c b/spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c new file mode 100644 index 0000000..d75dca2 --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c @@ -0,0 +1,138 @@ +#include + +#include "zn_arithmetic_private.h" + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return (nrows * ncols + 7) * sizeof(int32_t); +} + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols) { + return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols)); +} + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); } + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +) { + int32_t* const out = (int32_t*)pmat; + const uint64_t nblk = ncols >> 5; + const uint64_t ncols_rem = ncols & 31; + const uint64_t final_elems = (8 - nrows * ncols) & 7; + for (uint64_t blk = 0; blk < nblk; ++blk) { + int32_t* outblk = out + blk * nrows * 32; + const int32_t* srcblk = mat + blk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * 32; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < 32; ++i) { + dest[i] = src[i]; + } + } + } + // copy the last block if any + if (ncols_rem) { + int32_t* outblk = out + nblk * nrows * 32; + const int32_t* srcblk = mat + nblk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * ncols_rem; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < ncols_rem; ++i) { + dest[i] = src[i]; + } + } + } + // zero-out the final elements that may be accessed + if (final_elems) { + int32_t* f = out + nrows * ncols; + for (uint64_t i = 0; i < final_elems; 
++i) { + f[i] = 0; + } + } +} + +#if 0 + +#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \ + memset(res, 0, NCOLS * sizeof(int32_t)); \ + for (uint64_t row = 0; row < nrows; ++row) { \ + int32_t ai = a[row]; \ + const int32_t* bb = b + row * b_sl; \ + for (uint64_t i = 0; i < NCOLS; ++i) { \ + res[i] += ai * bb[i]; \ + } \ + } + +#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8) +#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16) +#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24) +#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32) + +void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} + +void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat24cols_ref() +} +void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat16cols_ref() +} +void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat8cols_ref() +} +typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, // + int32_t* res, // + const int32_t* a, // + const int32_t* b, uint64_t b_sl // +); +zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { // + zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, // + zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const int32_t* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 
32 : orig_rem; + int32_t tmp[32]; + zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} + +#endif diff --git a/spqlios/lib/spqlios/arithmetic/znx_small.c b/spqlios/lib/spqlios/arithmetic/znx_small.c new file mode 100644 index 0000000..8ca0b0e --- /dev/null +++ b/spqlios/lib/spqlios/arithmetic/znx_small.c @@ -0,0 +1,38 @@ +#include "vec_znx_arithmetic_private.h" + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void fft64_znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp) { + const uint64_t nn = module->nn; + double* const ffta = (double*)tmp; + double* const fftb = ((double*)tmp) + nn; + reim_from_znx64(module->mod.fft64.p_conv, ffta, a); + reim_from_znx64(module->mod.fft64.p_conv, fftb, b); + reim_fft(module->mod.fft64.p_fft, ffta); + reim_fft(module->mod.fft64.p_fft, fftb); + reim_fftvec_mul_simple(module->m, ffta, ffta, fftb); + reim_ifft(module->mod.fft64.p_ifft, ffta); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta); +} + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module) { + return 2 * module->nn * sizeof(double); +} + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp) { + module->func.znx_small_single_product(module, res, a, b, tmp); +} + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module) { + return module->func.znx_small_single_product_tmp_bytes(module); +} diff --git a/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c new file mode 100644 index 0000000..0fbbefe --- /dev/null +++ b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c @@ -0,0 +1,496 @@ +#include "coeffs_arithmetic.h" + +#include +#include + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); } + +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); } + +EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) { + const double invm = 1. 
/ m; + for (uint64_t i = 0; i < n; ++i) { + res[i] = a[i] * invm; + } +} + +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +// 0 < p < 2nn +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; // i*p mod 2n + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t 
new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp2 = res[new_j_n]; + res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n]; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. 
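+      // (more precisely: j -> j+p mod nn splits into gcd(nn,p) cycles, one per
+      //  residue class mod gcd(nn,p), so the consecutive starts 0,1,2,... each
+      //  open a fresh cycle until nb_modif reaches nn)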
+ ++j_start; + } +} + +__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) { + return (x << (64 - base_k)) >> (64 - base_k); +} + +__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) { + return (x - digit) >> base_k; +} + +EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, + const int64_t* carry_in) { + assert(in); + if (out != 0) { + if (carry_in != 0x0 && carry_out != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + } else if (carry_in != 0) { + // with carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + + out[i] = y; + } + + } else if (carry_out != 0) { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + + } else { + // no carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + out[i] = get_base_k_digit(in[i], base_k); + } + } + } else { + assert(carry_out); + if (carry_in != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + carry_out[i] = cout; + } + } else { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + carry_out[i] = cout; + } + } + } +} + +void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! 
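+    // (e.g. p = 2*nn - 1, i.e. X -> X^{-1}: already at binval = 1 every
+    //  coefficient of positive degree moves to its mirrored slot with a sign
+    //  flip, res[j] = -in[nn-j] for 0 < j < nn, while res[0] is unchanged)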
+ if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + int64_t tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! + if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + int64_t tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + int64_t tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp1a = res[new_j_n]; + int64_t tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. + j_start = (5 * j_start) & _mn; + } + } +} + +void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! + if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + double tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! 
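+    // (e.g. p = nn - 1 lands here already at binval = 1: every odd index is
+    //  simply exchanged with its mirror, res[j] <-> res[nn-j], no sign change)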
+ if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + double tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + double tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp1a = res[new_j_n]; + double tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. + j_start = (5 * j_start) & _mn; + } + } +} diff --git a/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h new file mode 100644 index 0000000..d4c9e5a --- /dev/null +++ b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h @@ -0,0 +1,78 @@ +#ifndef SPQLIOS_COEFFS_ARITHMETIC_H +#define SPQLIOS_COEFFS_ARITHMETIC_H + +#include "../commons.h" + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = -a */ +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a); +/** res = a */ +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +/** res = 0 */ +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res); + +/** res = a / m where m is a power of 2 */ +EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a); +EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a); + +/** + * @param res = X^p *in mod X^nn +1 + * @param nn the ring dimension + * @param p a power for the rotation -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res(X) = in(X^p) + * @param nn the ring dimension + * @param p is odd integer and must be between 0 < p < 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); 
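+// illustrative example (nn = 4, p = 3): the input a0 + a1.X + a2.X^2 + a3.X^3
+// is mapped to a0 + a3.X - a2.X^2 + a1.X^3, i.e. each X^(i.p) is reduced mod X^4 + 1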
+EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res = (X^p-1).in + * @param nn the ring dimension + * @param p must be between -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res); + +/** + * @brief Normalize input plus carry mod-2^k. The following + * equality holds @c {in + carry_in == out + carry_out . 2^k}. + * + * @c in must be in [-2^62 .. 2^62] + * + * @c out is in [ -2^(base_k-1), 2^(base_k-1) [. + * + * @c carry_in and @carry_out have at most 64+1-k bits. + * + * Null @c carry_in or @c carry_out are ignored. + * + * @param[in] nn the ring dimension + * @param[in] base_k the base k + * @param out output normalized znx + * @param carry_out output carry znx + * @param[in] in input znx + * @param[in] carry_in input carry znx + */ +EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, + const int64_t* carry_in); + +#endif // SPQLIOS_COEFFS_ARITHMETIC_H diff --git a/spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c new file mode 100644 index 0000000..8fcd608 --- /dev/null +++ b/spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c @@ -0,0 +1,124 @@ +#include + +#include "../commons_private.h" +#include "coeffs_arithmetic.h" + +// res = a + b. dimension n must be a power of 2 +EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + if (nn <= 2) { + if (nn == 1) { + res[0] = a[0] + b[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_add_epi64( // + _mm_loadu_si128((__m128i*)a), // + _mm_loadu_si128((__m128i*)b))); + } + } else { + const __m256i* aa = (__m256i*)a; + const __m256i* bb = (__m256i*)b; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_add_epi64( // + _mm256_loadu_si256(aa), // + _mm256_loadu_si256(bb))); + ++rr; + ++aa; + ++bb; + } while (rr < rrend); + } +} + +// res = a - b. 
dimension n must be a power of 2 +EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + if (nn <= 2) { + if (nn == 1) { + res[0] = a[0] - b[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_sub_epi64( // + _mm_loadu_si128((__m128i*)a), // + _mm_loadu_si128((__m128i*)b))); + } + } else { + const __m256i* aa = (__m256i*)a; + const __m256i* bb = (__m256i*)b; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_sub_epi64( // + _mm256_loadu_si256(aa), // + _mm256_loadu_si256(bb))); + ++rr; + ++aa; + ++bb; + } while (rr < rrend); + } +} + +EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) { + if (nn <= 2) { + if (nn == 1) { + res[0] = -a[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_sub_epi64( // + _mm_set1_epi64x(0), // + _mm_loadu_si128((__m128i*)a))); + } + } else { + const __m256i* aa = (__m256i*)a; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_sub_epi64( // + _mm256_set1_epi64x(0), // + _mm256_loadu_si256(aa))); + ++rr; + ++aa; + } while (rr < rrend); + } +} + +EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) { + // TODO: see if there is a faster way of dividing by a power of 2? + const double invm = 1. / m; + if (n < 8) { + switch (n) { + case 1: + *res = *a * invm; + break; + case 2: + _mm_storeu_pd(res, // + _mm_mul_pd(_mm_loadu_pd(a), // + _mm_set1_pd(invm))); + break; + case 4: + _mm256_storeu_pd(res, // + _mm256_mul_pd(_mm256_loadu_pd(a), // + _mm256_set1_pd(invm))); + break; + default: + NOT_SUPPORTED(); // non-power of 2 + } + return; + } + const __m256d invm256 = _mm256_set1_pd(invm); + double* rr = res; + const double* aa = a; + const double* const aaend = a + n; + do { + _mm256_storeu_pd(rr, // + _mm256_mul_pd(_mm256_loadu_pd(aa), // + invm256)); + _mm256_storeu_pd(rr + 4, // + _mm256_mul_pd(_mm256_loadu_pd(aa + 4), // + invm256)); + rr += 8; + aa += 8; + } while (aa < aaend); +} diff --git a/spqlios/lib/spqlios/commons.c b/spqlios/lib/spqlios/commons.c new file mode 100644 index 0000000..adcff79 --- /dev/null +++ b/spqlios/lib/spqlios/commons.c @@ -0,0 +1,165 @@ +#include "commons.h" + +#include +#include +#include + +EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); } +EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); } +EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); } +EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); } +EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); } +EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); } +EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); } +EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); } +EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); } +EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); } + +#ifdef _WIN32 +#define 
__always_inline inline __attribute((always_inline)) +#endif + +void internal_accurate_sincos(double* rcos, double* rsin, double x) { + double _4_x_over_pi = 4 * x / M_PI; + int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7; + double frac_part = _4_x_over_pi - (double)(int_part); + double frac_x = M_PI * frac_part / 4.; + // compute the taylor series + double cosp = 1.; + double sinp = 0.; + double powx = 1.; + int64_t nn = 0; + while (fabs(powx) > 1e-20) { + ++nn; + powx = powx * frac_x / (double)(nn); // x^n/n! + switch (nn & 3) { + case 0: + cosp += powx; + break; + case 1: + sinp += powx; + break; + case 2: + cosp -= powx; + break; + case 3: + sinp -= powx; + break; + default: + abort(); // impossible + } + } + // final multiplication + switch (int_part) { + case 0: + *rcos = cosp; + *rsin = sinp; + break; + case 1: + *rcos = M_SQRT1_2 * (cosp - sinp); + *rsin = M_SQRT1_2 * (cosp + sinp); + break; + case 2: + *rcos = -sinp; + *rsin = cosp; + break; + case 3: + *rcos = -M_SQRT1_2 * (cosp + sinp); + *rsin = M_SQRT1_2 * (cosp - sinp); + break; + case 4: + *rcos = -cosp; + *rsin = -sinp; + break; + case 5: + *rcos = -M_SQRT1_2 * (cosp - sinp); + *rsin = -M_SQRT1_2 * (cosp + sinp); + break; + case 6: + *rcos = sinp; + *rsin = -cosp; + break; + case 7: + *rcos = M_SQRT1_2 * (cosp + sinp); + *rsin = -M_SQRT1_2 * (cosp - sinp); + break; + default: + abort(); // impossible + } + if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) { + printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x)); + printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x)); + printf("fracx = %.17lf\n", frac_x); + printf("cosp = %.17lf\n", cosp); + printf("sinp = %.17lf\n", sinp); + printf("nn = %d\n", (int)(nn)); + } +} + +double internal_accurate_cos(double x) { + double rcos, rsin; + internal_accurate_sincos(&rcos, &rsin, x); + return rcos; +} +double internal_accurate_sin(double x) { + double rcos, rsin; + internal_accurate_sincos(&rcos, &rsin, x); + return rsin; +} + +EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); } + +EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; } + +EXPORT void spqlios_free(void* addr) { +#ifndef NDEBUG + // in debug mode, we deallocated with spqlios_debug_free() + spqlios_debug_free(addr); +#else + // in release mode, the function will free aligned memory +#ifdef _WIN32 + _aligned_free(addr); +#else + free(addr); +#endif +#endif +} + +EXPORT void* spqlios_alloc(uint64_t size) { +#ifndef NDEBUG + // in debug mode, the function will not necessarily have any particular alignment + // it will also ensure that memory can only be deallocated with spqlios_free() + return spqlios_debug_alloc(size); +#else + // in release mode, the function will return 64-bytes aligned memory +#ifdef _WIN32 + void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64); +#else + void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64))); +#endif + if (reps == 0) FATAL_ERROR("Out of memory"); + return reps; +#endif +} + +EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) { +#ifndef NDEBUG + // in debug mode, the function will not necessarily have any particular alignment + // it will also ensure that memory can only be deallocated with spqlios_free() + return spqlios_debug_alloc(size); +#else + // in release mode, the function will return aligned memory +#ifdef _WIN32 + void* reps = _aligned_malloc(size, align); +#else + void* reps = aligned_alloc(align, size); +#endif + if (reps == 0) 
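+  // note: unlike spqlios_alloc above, the size is passed through unrounded, so
+  // callers are presumably expected to request sizes that are multiples of
+  // align (as C11 aligned_alloc demands); a null result is treated as OOM: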
FATAL_ERROR("Out of memory"); + return reps; +#endif +} \ No newline at end of file diff --git a/spqlios/lib/spqlios/commons.h b/spqlios/lib/spqlios/commons.h new file mode 100644 index 0000000..653d083 --- /dev/null +++ b/spqlios/lib/spqlios/commons.h @@ -0,0 +1,77 @@ +#ifndef SPQLIOS_COMMONS_H +#define SPQLIOS_COMMONS_H + +#ifdef __cplusplus +#include +#include +#include +#define EXPORT extern "C" +#define EXPORT_DECL extern "C" +#else +#include +#include +#include +#define EXPORT +#define EXPORT_DECL extern +#define nullptr 0x0; +#endif + +#define UNDEFINED() \ + { \ + fprintf(stderr, "UNDEFINED!!!\n"); \ + abort(); \ + } +#define NOT_IMPLEMENTED() \ + { \ + fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \ + abort(); \ + } +#define FATAL_ERROR(MESSAGE) \ + { \ + fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \ + abort(); \ + } + +EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m); +EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m); +EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n); +EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n); +EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n); +EXPORT void UNDEFINED_v_vpdp(const void* p, double* a); +EXPORT void UNDEFINED_v_vpvp(const void* p, void* a); +EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n); +EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n); +EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n); +EXPORT void NOT_IMPLEMENTED_v_dp(double* a); +EXPORT void NOT_IMPLEMENTED_v_vp(void* p); +EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c); +EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b); +EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o); + +// windows + +#if defined(_WIN32) || defined(__APPLE__) +#define __always_inline inline __attribute((always_inline)) +#endif + +EXPORT void spqlios_free(void* address); + +EXPORT void* spqlios_alloc(uint64_t size); +EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size); + +#define USE_LIBM_SIN_COS +#ifndef USE_LIBM_SIN_COS +// if at some point, we want to remove the libm dependency, we can +// consider this: +EXPORT double internal_accurate_cos(double x); +EXPORT double internal_accurate_sin(double x); +EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x); +#define m_accurate_cos internal_accurate_cos +#define m_accurate_sin internal_accurate_sin +#else +// let's use libm sin and cos +#define m_accurate_cos cos +#define m_accurate_sin sin +#endif + +#endif // SPQLIOS_COMMONS_H diff --git a/spqlios/lib/spqlios/commons_private.c b/spqlios/lib/spqlios/commons_private.c new file mode 100644 index 0000000..fddc190 --- /dev/null +++ b/spqlios/lib/spqlios/commons_private.c @@ -0,0 +1,55 @@ +#include "commons_private.h" + +#include +#include + +#include "commons.h" + +EXPORT void* spqlios_error(const char* error) { + fputs(error, stderr); + abort(); + return nullptr; +} +EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) { + if (!ptr2) { + free(ptr); + } + return ptr2; +} + +EXPORT uint32_t log2m(uint32_t m) { + uint32_t a = m - 1; + if (m & a) FATAL_ERROR("m must be a power of two"); + a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u); + a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u); + a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu); + a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu); + return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu); +} + +EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 
0x7FFFFFFFFFFFFUL; } + +uint32_t revbits(uint32_t nbits, uint32_t value) { + uint32_t res = 0; + for (uint32_t i = 0; i < nbits; ++i) { + res = (res << 1) + (value & 1); + value >>= 1; + } + return res; +} + +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... + * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +double fracrevbits(uint32_t i) { + if (i == 0) return 0; + if (i == 1) return 0.5; + if (i % 2 == 0) + return fracrevbits(i / 2) / 2.; + else + return fracrevbits((i - 1) / 2) / 2. + 0.5; +} + +uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); } + +uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); } diff --git a/spqlios/lib/spqlios/commons_private.h b/spqlios/lib/spqlios/commons_private.h new file mode 100644 index 0000000..e2b0514 --- /dev/null +++ b/spqlios/lib/spqlios/commons_private.h @@ -0,0 +1,72 @@ +#ifndef SPQLIOS_COMMONS_PRIVATE_H +#define SPQLIOS_COMMONS_PRIVATE_H + +#include "commons.h" + +#ifdef __cplusplus +#include +#include +#include +#else +#include +#include +#include +#define nullptr 0x0; +#endif + +/** @brief log2 of a power of two (UB if m is not a power of two) */ +EXPORT uint32_t log2m(uint32_t m); + +/** @brief checks if the doublevalue is a power of two */ +EXPORT uint64_t is_not_pow2_double(void* doublevalue); + +#define UNDEFINED() \ + { \ + fprintf(stderr, "UNDEFINED!!!\n"); \ + abort(); \ + } +#define NOT_IMPLEMENTED() \ + { \ + fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \ + abort(); \ + } +#define NOT_SUPPORTED() \ + { \ + fprintf(stderr, "NOT SUPPORTED!!!\n"); \ + abort(); \ + } +#define FATAL_ERROR(MESSAGE) \ + { \ + fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \ + abort(); \ + } + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +/** @brief reports the error and returns nullptr */ +EXPORT void* spqlios_error(const char* error); +/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */ +EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2); + +#ifdef __x86_64__ +#define CPU_SUPPORTS __builtin_cpu_supports +#else +// TODO for now, we do not have any optimization for non x86 targets +#define CPU_SUPPORTS(xxxx) 0 +#endif + +/** @brief returns the n bits of value in reversed order */ +EXPORT uint32_t revbits(uint32_t nbits, uint32_t value); + +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... + * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +EXPORT double fracrevbits(uint32_t i); + +/** @brief smallest multiple of 64 higher or equal to size */ +EXPORT uint64_t ceilto64b(uint64_t size); + +/** @brief smallest multiple of 32 higher or equal to size */ +EXPORT uint64_t ceilto32b(uint64_t size); + +#endif // SPQLIOS_COMMONS_PRIVATE_H diff --git a/spqlios/lib/spqlios/cplx/README.md b/spqlios/lib/spqlios/cplx/README.md new file mode 100644 index 0000000..443cfa4 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/README.md @@ -0,0 +1,22 @@ +In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`. 
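+(Illustrative example: for N = 8, i.e. M = 4, a real negacyclic polynomial
+a_0 + a_1.X + ... + a_7.X^7 is stored as the four complexes
+(a_0,a_4), (a_1,a_5), (a_2,a_6), (a_3,a_7), which is exactly the
+coefficient-space layout described below.)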
+One complex is represented by two consecutive doubles `(real,imag)` +Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1 +corresponds to the complex polynomial of half degree `M=N/2`: +`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i` + +For a complex polynomial A(X) sum c_i X^i of degree M-1 +or a real polynomial sum a_i X^i of degree N + +coefficient space: +a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1} +or equivalently +Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1}) + +eval space: +c(omega_{0}),...,c(omega_{M-1}) + +where +omega_j = omega^{1+rev_{2N}(j)} +and omega = exp(i.pi/N) + +rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order. \ No newline at end of file diff --git a/spqlios/lib/spqlios/cplx/cplx_common.c b/spqlios/lib/spqlios/cplx/cplx_common.c new file mode 100644 index 0000000..1d9f509 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_common.c @@ -0,0 +1,80 @@ +#include "cplx_fft_internal.h" + +void cplx_set(CPLX r, const CPLX a) { + r[0] = a[0]; + r[1] = a[1]; +} +void cplx_neg(CPLX r, const CPLX a) { + r[0] = -a[0]; + r[1] = -a[1]; +} +void cplx_add(CPLX r, const CPLX a, const CPLX b) { + r[0] = a[0] + b[0]; + r[1] = a[1] + b[1]; +} +void cplx_sub(CPLX r, const CPLX a, const CPLX b) { + r[0] = a[0] - b[0]; + r[1] = a[1] - b[1]; +} +void cplx_mul(CPLX r, const CPLX a, const CPLX b) { + double re = a[0] * b[0] - a[1] * b[1]; + r[1] = a[0] * b[1] + a[1] * b[0]; + r[0] = re; +} + +/** + * @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial + * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) { + CPLX* d0 = data; + CPLX* d1 = data + h; + for (uint64_t i = 0; i < h; ++i) { + CPLX diff; + cplx_sub(diff, d0[i], d1[i]); + cplx_add(d0[i], d0[i], d1[i]); + cplx_mul(d1[i], diff, powom); + } +} + +/** + * @brief Do two layers of itwiddle (i.e. split). 
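+ * (here "itwiddle" refers to the split butterfly implemented by cplx_split_fft_ref above)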
+ * Input/output: d0,d1,d2,d3 of length h + * Algo: + * itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0]) + * itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1]) + * @param h number of "coefficients" h >= 1 + * @param data 4h complex coefficients interleaved and 256b aligned + * @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2 + */ +EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) { + CPLX* d0 = data; + CPLX* d2 = data + 2*h; + const CPLX* om0 = powom; + CPLX iom0; + iom0[0]=powom[0][1]; + iom0[1]=-powom[0][0]; + const CPLX* om1 = powom+1; + cplx_split_fft_ref(h, d0, *om0); + cplx_split_fft_ref(h, d2, iom0); + cplx_split_fft_ref(2*h, d0, *om1); +} + +/** + * Input: Q(y),Q(-y) + * Output: P_0(z),P_1(z) + * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z + * @param data 2 complexes coefficients interleaved and 256b aligned + * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) + */ +void split_fft_last_ref(CPLX* data, const CPLX powom) { + CPLX diff; + cplx_sub(diff, data[0], data[1]); + cplx_add(data[0], data[0], data[1]); + cplx_mul(data[1], diff, powom); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_conversions.c b/spqlios/lib/spqlios/cplx/cplx_conversions.c new file mode 100644 index 0000000..d912ccf --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_conversions.c @@ -0,0 +1,158 @@ +#include +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { + const uint32_t m = precomp->m; + const int32_t* inre = x; + const int32_t* inim = x + m; + CPLX* out = r; + for (uint32_t i = 0; i < m; ++i) { + out[i][0] = (double)inre[i]; + out[i][1] = (double)inim[i]; + } +} + +EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { + static const double _2p32 = 1. 
/ (INT64_C(1) << 32); + const uint32_t m = precomp->m; + const int32_t* inre = x; + const int32_t* inim = x + m; + CPLX* out = r; + for (uint32_t i = 0; i < m; ++i) { + out[i][0] = ((double)inre[i]) * _2p32; + out[i][1] = ((double)inim[i]) * _2p32; + } +} + +EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { + static const double _2p32 = (INT64_C(1) << 32); + const uint32_t m = precomp->m; + double factor = _2p32 / precomp->divisor; + int32_t* outre = r; + int32_t* outim = r + m; + const CPLX* in = x; + // Note: this formula will only work if abs(in) < 2^32 + for (uint32_t i = 0; i < m; ++i) { + outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor)); + outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor)); + } +} + +void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = cplx_from_znx32_avx2_fma; + } else { + res->function = cplx_from_znx32_ref; + } + } else { + res->function = cplx_from_znx32_ref; + } + return res; +} + +CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) { + CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m)); +} + +void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = cplx_from_tnx32_avx2_fma; + } else { + res->function = cplx_from_tnx32_ref; + } + } else { + res->function = cplx_from_tnx32_ref; + } + return res; +} + +CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) { + CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m)); +} + +void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) { + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (log2overhead > 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + if (CPU_SUPPORTS("avx2")) { + if (log2overhead <= 18) { + if (m >= 8) { + res->function = cplx_to_tnx32_avx2_fma; + } else { + res->function = cplx_to_tnx32_ref; + } + } else { + res->function = cplx_to_tnx32_ref; + } + } else { + res->function = cplx_to_tnx32_ref; + } + return res; +} + +EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead)); +} + +/** + * @brief Simpler API for the znx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
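+ * (e.g. the first call with a given m builds a static table indexed by log2(m) and caches it;
+ * subsequent calls with the same m reuse that cached table.)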
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) { + // not checking for log2bound which is not relevant here + static CPLX_FROM_ZNX32_PRECOMP precomp[32]; + CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_cplx_from_znx32_precomp(p, m)) abort(); + } + p->function(p, r, x); +} + +/** + * @brief Simpler API for the tnx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) { + static CPLX_FROM_TNX32_PRECOMP precomp[32]; + CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_cplx_from_tnx32_precomp(p, m)) abort(); + } + p->function(p, r, x); +} +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) { + struct LAST_CPLX_TO_TNX32_PRECOMP { + CPLX_TO_TNX32_PRECOMP p; + double last_divisor; + double last_log2over; + }; + static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32]; + struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m); + if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) { + memset(p, 0, sizeof(*p)); + if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort(); + p->last_divisor = divisor; + p->last_log2over = log2overhead; + } + p->p.function(&p->p, r, x); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c b/spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c new file mode 100644 index 0000000..9dc19db --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c @@ -0,0 +1,104 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef int32_t I8MEM[8]; +typedef double D4MEM[4]; + +__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) { + const __m256i S = _mm256_set1_epi32(0x80000000); + const I8MEM* inre = (I8MEM*)(x); + const I8MEM* inim = (I8MEM*)(x+m); + D4MEM* out = (D4MEM*) r; + const uint64_t ms8 = m/8; + for (uint32_t i=0; im; + cplx_from_any_fma(m, r, x, C, R); +} + +EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { + //note: the hex code of 2^-1 + 2^30 is 0x4130000080000000 + const __m256i C = _mm256_set1_epi32(0x41300000); + const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20)); + // double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32); + //printf("\n\n%016lx\n", *(uint64_t*)&XX); + //abort(); + const uint64_t m = precomp->m; + cplx_from_any_fma(m, r, x, C, R); +} + +EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { + const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor); + const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL); + const __m256i S = _mm256_set1_epi32(0x80000000); + //const __m256i IDX = 
_mm256_set_epi32(0,4,1,5,2,6,3,7); + const __m256i IDX = _mm256_set_epi32(7,3,6,2,5,1,4,0); + const uint64_t m = precomp->m; + const uint64_t ms8 = m/8; + I8MEM* outre = (I8MEM*) r; + I8MEM* outim = (I8MEM*) (r+m); + const D4MEM* in = x; + // Note: this formula will only work if abs(in) < 2^32 + for (uint32_t i=0; ifunction(tables, r, a); +} +EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} +EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { + tables->function(tables, r, a); +} +EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) { + tables->function(tables, r, a, b); +} +EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { + tables->function(tables, r, a, b); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c b/spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c new file mode 100644 index 0000000..6233f46 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c @@ -0,0 +1,41 @@ +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { + UNDEFINED(); // not defined for non x86 targets +} +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, + const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } +EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } +EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } +EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } +EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); } +EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT + void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT + void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){ + UNDEFINED()} EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, + const void* om){UNDEFINED()} EXPORT + void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om){UNDEFINED()} EXPORT + void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om){UNDEFINED()} + +// DEPRECATED? 
+EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT + void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT + void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a){UNDEFINED()} + +// executors +//EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { +// itables->function(itables, data); +//} +//EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); } diff --git a/spqlios/lib/spqlios/cplx/cplx_fft.h b/spqlios/lib/spqlios/cplx/cplx_fft.h new file mode 100644 index 0000000..01699bd --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft.h @@ -0,0 +1,221 @@ +#ifndef SPQLIOS_CPLX_FFT_H +#define SPQLIOS_CPLX_FFT_H + +#include "../commons.h" + +typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP; +typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP; +typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP; +typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP; +typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP; +typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP; +typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP; +typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP; +typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP; +typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP; +typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP; + +/** + * @brief precomputes fft tables. + * The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn. + * The resulting pointer is to be passed as "tables" argument to any call to the fft function. + * The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to + * the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure + * that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft + * table must be deleted by delete_fft_precomp after its last usage. + */ +EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers); + +/** + * @brief gets the address of a fft buffer allocated during new_fft_precomp. + * This buffer can be used as data pointer in subsequent calls to fft, + * and does not need to be released afterwards. + */ +EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT void* new_cplx_fft_buffer(uint32_t m); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT void delete_cplx_fft_buffer(void* buffer); + +/** + * @brief deallocates a fft table and all its built-in buffers. + */ +#define delete_cplx_fft_precomp free + +/** + * @brief computes a direct fft in-place over data. 
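+ *
+ * Illustrative sketch (not normative; how `buf` gets filled with m complexes is up to the
+ * caller, e.g. via cplx_from_znx32 or cplx_from_tnx32):
+ *
+ *   CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 1);  // one built-in scratch buffer
+ *   void* buf = cplx_fft_precomp_get_buffer(tables, 0);     // no separate free needed
+ *   // ... write m complexes into buf ...
+ *   cplx_fft(tables, buf);                                   // in-place FFT
+ *   delete_cplx_fft_precomp(tables);                         // also frees the built-in buffers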
+ */ +EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data); + +EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers); +EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index); +EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data); +#define delete_cplx_ifft_precomp free + +EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m); +EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); +#define delete_cplx_fftvec_mul_precomp free + +EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m); +EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b); +#define delete_cplx_fftvec_addmul_precomp free + +/** + * @brief prepares a conversion from ZnX to the cplx layout. + * All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use + * this function on a larger coefficient is undefined behaviour. The resulting precomputed data must + * be freed with `new_cplx_from_znx32_precomp` + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * int32 coefficients in natural order modulo X^n+1 + * @param log2bound bound on the input coefficients. Must be between 0 and 32 + */ +EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m); +/** + * @brief converts from ZnX to the cplx layout. + * @param tables precomputed data obtained by new_cplx_from_znx32_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n bounded integer coefficients mod X^n+1 + */ +EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */ +#define delete_cplx_from_znx32_precomp free + +/** + * @brief prepares a conversion from TnX to the cplx layout. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * torus32 coefficients. The resulting precomputed data must + * be freed with `delete_cplx_from_tnx32_precomp` + */ +EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m); +/** + * @brief converts from TnX to the cplx layout. + * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n torus32 coefficients mod X^n+1 + */ +EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */ +#define delete_cplx_from_tnx32_precomp free + +/** + * @brief prepares a rescale and conversion from the cplx layout to TnX. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m + * torus32 coefficients. + * @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1. + * Remember that the output of an iFFT must be divided by m. + * @param log2overhead all inputs absolute values must be within divisor.2^log2overhead. + * For any inputs outside of these bounds, the conversion is undefined behaviour. + * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18. 
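+ *
+ * Illustrative sketch (not normative): converting the m complex outputs `ev` of an iFFT in
+ * dimension m back to n=2m torus32 coefficients, hence divisor = m and, assuming the inputs
+ * stay below m.2^18 in absolute value, log2overhead = 18:
+ *
+ *   CPLX_TO_TNX32_PRECOMP* p = new_cplx_to_tnx32_precomp(m, (double)m, 18);
+ *   cplx_to_tnx32(p, out_torus32, ev);   // rescale by 1/m, reduce mod 1, convert to torus32
+ *   delete_cplx_to_tnx32_precomp(p);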
+ */ +EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead); +/** + * @brief rescale, converts and reduce mod 1 from cplx layout to torus32. + * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp. + * @param r resulting array of n torus32 coefficients mod X^n+1 + * @param x input array of m cplx coefficients mod X^m-i + */ +EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a); +#define delete_cplx_to_tnx32_precomp free + +EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor); +EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x); +#define delete_cplx_to_znx32_simple free + +EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m); +EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +#define delete_cplx_from_rnx64_simple free + +EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor); +EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +#define delete_cplx_round_to_rnx64_simple free + +EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound); +EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +#define delete_cplx_round_to_rnx64_simple free + +/** + * @brief Simpler API for the fft function. + * For each dimension, the precomputed tables for this dimension are generated automatically. + * It is advised to do one dry-run per desired dimension before using in a multithread environment */ +EXPORT void cplx_fft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the ifft function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread + * environment */ +EXPORT void cplx_ifft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the fftvec multiplication function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the fftvec addmul function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the znx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the tnx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x); + +/** + * @brief converts, divides and round from cplx to znx32 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. + */ +EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x); + +/** + * @brief converts from rnx64 to cplx (simple API) + * The bound on the output is assumed to be within ]2^-31,2^31[. + * Any coefficient that would fall outside this range is undefined behaviour. + * @param m the complex dimension + * @param r the result: must be an array of m complex numbers. r must be distinct from x + * @param x the input: must be an array of 2m doubles. + */ +EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x); + +/** + * @brief converts, divides from cplx to rnx64 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. + */ +EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x); + +/** + * @brief converts, divides and round to integer from cplx to rnx32 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm. + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. 
+ */ +EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x); + +#endif // SPQLIOS_CPLX_FFT_H diff --git a/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s b/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s new file mode 100644 index 0000000..40e3985 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s @@ -0,0 +1,156 @@ +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +.globl cplx_fft16_avx_fma +cplx_fft16_avx_fma: +vmovupd (%rdi),%ymm8 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.first_pass: +vmovupd (%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm1,%ymm6 +vmulpd %ymm7,%ymm1,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmaddsub231pd %ymm13, %ymm0, %ymm5 +vfmaddsub231pd %ymm14, %ymm0, %ymm6 +vfmaddsub231pd %ymm15, %ymm0, %ymm7 +vsubpd %ymm4,%ymm8,%ymm12 +vsubpd %ymm5,%ymm9,%ymm13 +vsubpd %ymm6,%ymm10,%ymm14 +vsubpd %ymm7,%ymm11,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 + +.second_pass: +vmovupd 16(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm10, %ymm10, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm0,%ymm6 +vmulpd %ymm7,%ymm0,%ymm7 +vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmaddsub231pd %ymm11, %ymm0, %ymm5 +vfmsubadd231pd %ymm14, %ymm1, %ymm6 +vfmsubadd231pd %ymm15, %ymm1, %ymm7 +vsubpd %ymm4,%ymm8,%ymm10 +vsubpd %ymm5,%ymm9,%ymm11 +vaddpd %ymm6,%ymm12,%ymm14 +vaddpd %ymm7,%ymm13,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vsubpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm13,%ymm13 + +.third_pass: +vmovupd 32(%rsi),%xmm0 /* gamma */ +vmovupd 48(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vshufpd $5, %ymm9, %ymm9, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm13, %ymm13, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmsubadd231pd %ymm11, %ymm1, %ymm5 +vfmaddsub231pd %ymm13, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 +vsubpd %ymm4,%ymm8,%ymm9 +vaddpd %ymm5,%ymm10,%ymm11 +vsubpd %ymm6,%ymm12,%ymm13 +vaddpd %ymm7,%ymm14,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 
+vsubpd %ymm5,%ymm10,%ymm10 +vaddpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm14,%ymm14 + +.fourth_pass: +vmovupd 64(%rsi),%ymm0 /* gamma */ +vmovupd 96(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm0,%ymm13 +vmulpd %ymm14,%ymm3,%ymm14 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmsubadd231pd %ymm5, %ymm1, %ymm13 +vfmaddsub231pd %ymm6, %ymm2, %ymm14 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 +vsubpd %ymm12,%ymm8,%ymm4 +vaddpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vaddpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vsubpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vsubpd %ymm15,%ymm11,%ymm11 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret +.size cplx_fft16_avx_fma, .-cplx_fft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s b/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s new file mode 100644 index 0000000..d7d4bf1 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s @@ -0,0 +1,190 @@ + .text + .p2align 4 + .globl cplx_fft16_avx_fma + .def cplx_fft16_avx_fma; .scl 2; .type 32; .endef +cplx_fft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq cplx_fft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +# shifted FFT over X^16-i +# 1st argument (rdi) 
contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +cplx_fft16_avx_fma_amd64: +vmovupd (%rdi),%ymm8 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.first_pass: +vmovupd (%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm1,%ymm6 +vmulpd %ymm7,%ymm1,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmaddsub231pd %ymm13, %ymm0, %ymm5 +vfmaddsub231pd %ymm14, %ymm0, %ymm6 +vfmaddsub231pd %ymm15, %ymm0, %ymm7 +vsubpd %ymm4,%ymm8,%ymm12 +vsubpd %ymm5,%ymm9,%ymm13 +vsubpd %ymm6,%ymm10,%ymm14 +vsubpd %ymm7,%ymm11,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 + +.second_pass: +vmovupd 16(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm10, %ymm10, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm0,%ymm6 +vmulpd %ymm7,%ymm0,%ymm7 +vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmaddsub231pd %ymm11, %ymm0, %ymm5 +vfmsubadd231pd %ymm14, %ymm1, %ymm6 +vfmsubadd231pd %ymm15, %ymm1, %ymm7 +vsubpd %ymm4,%ymm8,%ymm10 +vsubpd %ymm5,%ymm9,%ymm11 +vaddpd %ymm6,%ymm12,%ymm14 +vaddpd %ymm7,%ymm13,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vsubpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm13,%ymm13 + +.third_pass: +vmovupd 32(%rsi),%xmm0 /* gamma */ +vmovupd 48(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vshufpd $5, %ymm9, %ymm9, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm13, %ymm13, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmsubadd231pd %ymm11, %ymm1, %ymm5 +vfmaddsub231pd %ymm13, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 +vsubpd %ymm4,%ymm8,%ymm9 +vaddpd %ymm5,%ymm10,%ymm11 +vsubpd %ymm6,%ymm12,%ymm13 +vaddpd %ymm7,%ymm14,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vsubpd %ymm5,%ymm10,%ymm10 +vaddpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm14,%ymm14 + +.fourth_pass: +vmovupd 64(%rsi),%ymm0 /* gamma */ +vmovupd 96(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x 
gamma +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm0,%ymm13 +vmulpd %ymm14,%ymm3,%ymm14 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmsubadd231pd %ymm5, %ymm1, %ymm13 +vfmaddsub231pd %ymm6, %ymm2, %ymm14 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 +vsubpd %ymm12,%ymm8,%ymm4 +vaddpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vaddpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vsubpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vsubpd %ymm15,%ymm11,%ymm11 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_asserts.c b/spqlios/lib/spqlios/cplx/cplx_fft_asserts.c new file mode 100644 index 0000000..c5937ef --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_asserts.c @@ -0,0 +1,8 @@ +#include "cplx_fft_private.h" +#include "../commons_private.h" + +__always_inline void my_asserts() { + STATIC_ASSERT(sizeof(FFT_FUNCTION)==8); + STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP)==40); + STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP)==40); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c b/spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c new file mode 100644 index 0000000..35b6ace --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c @@ -0,0 +1,266 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; + +/** + * @brief complex fft via bfs strategy (for m between 2 and 8) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_bfs_2(D4MEM* dat, const D4MEM** omg, uint32_t m) { + double* data = (double*)dat; + int32_t _2nblock = m >> 1; // = h in ref code + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + while (_2nblock >= 2) { + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = 
_mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + _mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omg += 1; + } while (dd < finaldd); + _2nblock >>= 1; + } + // last iteration when _2nblock == 1 + { + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d ab = _mm256_loadu_pd(dd[0]); + const __m256d bb = _mm256_permute4x64_pd(ab, 0b11101110); + const __m256d bbbar = _mm256_permute4x64_pd(ab, 0b10111011); + const __m256d t1 = _mm256_mul_pd(bbbar, omim); + const __m256d t2 = _mm256_fmaddsub_pd(bb, omre, t1); + const __m256d aa = _mm256_permute4x64_pd(ab, 0b01000100); + const __m256d newab = _mm256_add_pd(aa, t2); + _mm256_storeu_pd(dd[0], newab); + dd += 1; + *omg += 1; + } while (dd < finaldd); + } +} + +__always_inline void cplx_twiddle_fft_avx2(int32_t h, D4MEM* data, const void* omg) { + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* d0 = data; + D4MEM* const ddend = d0 + (h>>1); + D4MEM* d1 = ddend; + do { + const __m256d b = _mm256_loadu_pd(d1[0]); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t1 = _mm256_mul_pd(barb, omim); + const __m256d t2 = _mm256_fmaddsub_pd(b, omre, t1); + const __m256d a = _mm256_loadu_pd(d0[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(d0[0], newa); + _mm256_storeu_pd(d1[0], newb); + d0 += 1; + d1 += 1; + } while (d0 < ddend); +} + +__always_inline void cplx_bitwiddle_fft_avx2(int32_t h, void* data, const void* powom) { + const __m256d omx = _mm256_loadu_pd(powom); + const __m256d oma = _mm256_permute2f128_pd(omx, omx, 0x00); + const __m256d omb = _mm256_permute2f128_pd(omx, omx, 0x11); + const __m256d omaim = _mm256_unpackhi_pd(oma, oma); + const __m256d omare = _mm256_unpacklo_pd(oma, oma); + const __m256d ombim = _mm256_unpackhi_pd(omb, omb); + const __m256d ombre = _mm256_unpacklo_pd(omb, omb); + D4MEM* d0 = (D4MEM*) data; + D4MEM* const ddend = d0 + (h>>1); + D4MEM* d1 = ddend; + D4MEM* d2 = d0+h; + D4MEM* d3 = d1+h; + __m256d reg0,reg1,reg2,reg3,tmp0,tmp1; + do { + reg0 = _mm256_loadu_pd(d0[0]); + reg1 = _mm256_loadu_pd(d1[0]); + reg2 = _mm256_loadu_pd(d2[0]); + reg3 = _mm256_loadu_pd(d3[0]); + tmp0 = _mm256_shuffle_pd(reg2, reg2, 5); + tmp1 = _mm256_shuffle_pd(reg3, reg3, 5); + tmp0 = _mm256_mul_pd(tmp0, omaim); + tmp1 = _mm256_mul_pd(tmp1, omaim); + tmp0 = _mm256_fmaddsub_pd(reg2, omare, tmp0); + tmp1 = _mm256_fmaddsub_pd(reg3, omare, tmp1); + reg2 = _mm256_sub_pd(reg0, tmp0); + reg3 = _mm256_sub_pd(reg1, tmp1); + reg0 = _mm256_add_pd(reg0, tmp0); + reg1 = _mm256_add_pd(reg1, tmp1); + //-------------------------------------- + tmp0 = _mm256_shuffle_pd(reg1, reg1, 5); + tmp1 = _mm256_shuffle_pd(reg3, reg3, 5); + tmp0 = _mm256_mul_pd(tmp0, ombim); //(r,i) + tmp1 = _mm256_mul_pd(tmp1, ombre); //(-i,r) + tmp0 = _mm256_fmaddsub_pd(reg1, ombre, tmp0); + tmp1 = _mm256_fmsubadd_pd(reg3, ombim, tmp1); + reg1 = _mm256_sub_pd(reg0, tmp0); + reg3 = _mm256_add_pd(reg2, tmp1); + reg0 = 
_mm256_add_pd(reg0, tmp0); + reg2 = _mm256_sub_pd(reg2, tmp1); + ///// + _mm256_storeu_pd(d0[0], reg0); + _mm256_storeu_pd(d1[0], reg1); + _mm256_storeu_pd(d2[0], reg2); + _mm256_storeu_pd(d3[0], reg3); + d0 += 1; + d1 += 1; + d2 += 1; + d3 += 1; + } while (d0 < ddend); +} + +/** + * @brief complex fft via bfs strategy (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_bfs_16(D4MEM* dat, const D4MEM** omg, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + uint32_t mm = m; + uint32_t log2m = _mm_popcnt_u32(m-1); // log2(m) + if (log2m % 2 == 1) { + uint32_t h = mm>>1; + cplx_twiddle_fft_avx2(h, dat, **omg); + *omg += 1; + mm >>= 1; + } + while(mm>16) { + uint32_t h = mm/4; + for (CPLX* d = (CPLX*) data; d < (CPLX*) finaldd; d += mm) { + cplx_bitwiddle_fft_avx2(h, d, (CPLX*) *omg); + *omg += 1; + } + mm=h; + } + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_fft16_avx_fma(dd, *omg); + dd += 8; + *omg += 4; + } while (dd < finaldd); + _mm256_zeroupper(); + } + /* + int32_t _2nblock = m >> 1; // = h in ref code + while (_2nblock >= 16) { + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + _mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omg += 1; + } while (dd < finaldd); + _2nblock >>= 1; + } + // last iteration when _2nblock == 8 + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_fft16_avx_fma(dd, *omg); + dd += 8; + *omg += 4; + } while (dd < finaldd); + _mm256_zeroupper(); + } + */ +} + +/** + * @brief complex fft via dfs recursion (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_rec_16(D4MEM* dat, const D4MEM** omg, uint32_t m) { + if (m <= 8) return cplx_fft_avx2_fma_bfs_2(dat, omg, m); + if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(dat, omg, m); + double* data = (double*)dat; + int32_t _2nblock = m >> 1; // = h in ref code + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + 
_mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + *omg += 1; + cplx_fft_avx2_fma_rec_16(dat, omg, _2nblock); + cplx_fft_avx2_fma_rec_16(ddend, omg, _2nblock); +} + +/** + * @brief complex fft via best strategy (for m>=1) + * @param dat the data to run the algorithm on: m complex numbers + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* precomp, void* d) { + const uint32_t m = precomp->m; + const D4MEM* omg = (D4MEM*)precomp->powomegas; + if (m <= 1) return; + if (m <= 8) return cplx_fft_avx2_fma_bfs_2(d, &omg, m); + if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(d, &omg, m); + cplx_fft_avx2_fma_rec_16(d, &omg, m); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_avx512.c b/spqlios/lib/spqlios/cplx/cplx_fft_avx512.c new file mode 100644 index 0000000..23915b4 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_avx512.c @@ -0,0 +1,453 @@ +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D2MEM[2]; +typedef double D4MEM[4]; +typedef double D8MEM[8]; + +EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, + const void* b) { + const uint32_t m = precomp->m; + const D8MEM* aa = (D8MEM*)a; + const D8MEM* bb = (D8MEM*)b; + D8MEM* rr = (D8MEM*)r; + const D8MEM* const aend = aa + (m >> 2); + do { + /* +BEGIN_TEMPLATE +const __m512d ari% = _mm512_loadu_pd(aa[%]); +const __m512d bri% = _mm512_loadu_pd(bb[%]); +const __m512d rri% = _mm512_loadu_pd(rr[%]); +const __m512d bir% = _mm512_shuffle_pd(bri%,bri%, 0b01010101); +const __m512d aii% = _mm512_shuffle_pd(ari%,ari%, 0b11111111); +const __m512d pro% = _mm512_fmaddsub_pd(aii%,bir%,rri%); +const __m512d arr% = _mm512_shuffle_pd(ari%,ari%, 0b00000000); +const __m512d res% = _mm512_fmaddsub_pd(arr%,bri%,pro%); +_mm512_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 +const __m512d ari0 = _mm512_loadu_pd(aa[0]); +const __m512d ari1 = _mm512_loadu_pd(aa[1]); +const __m512d bri0 = _mm512_loadu_pd(bb[0]); +const __m512d bri1 = _mm512_loadu_pd(bb[1]); +const __m512d rri0 = _mm512_loadu_pd(rr[0]); +const __m512d rri1 = _mm512_loadu_pd(rr[1]); +const __m512d bir0 = _mm512_shuffle_pd(bri0,bri0, 0b01010101); +const __m512d bir1 = _mm512_shuffle_pd(bri1,bri1, 0b01010101); +const __m512d aii0 = _mm512_shuffle_pd(ari0,ari0, 0b11111111); +const __m512d aii1 = _mm512_shuffle_pd(ari1,ari1, 0b11111111); +const __m512d pro0 = _mm512_fmaddsub_pd(aii0,bir0,rri0); +const __m512d pro1 = _mm512_fmaddsub_pd(aii1,bir1,rri1); +const __m512d arr0 = _mm512_shuffle_pd(ari0,ari0, 0b00000000); +const __m512d arr1 = _mm512_shuffle_pd(ari1,ari1, 0b00000000); +const __m512d res0 = _mm512_fmaddsub_pd(arr0,bri0,pro0); +const __m512d res1 = _mm512_fmaddsub_pd(arr1,bri1,pro1); +_mm512_storeu_pd(rr[0],res0); +_mm512_storeu_pd(rr[1],res1); +rr += 2; // ONCE +aa += 2; // ONCE +bb += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +#if 0 +EXPORT void cplx_fftvec_mul_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = 
_mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b +const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a +const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a +const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a +const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a +const __m256d pro0 = _mm256_mul_pd(aii0,bir0); +const __m256d pro1 = _mm256_mul_pd(aii1,bir1); +const __m256d pro2 = _mm256_mul_pd(aii2,bir2); +const __m256d pro3 = _mm256_mul_pd(aii3,bir3); +const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a +const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a +const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a +const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a +const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0); +const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1); +const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2); +const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d res% = _mm256_add_pd(ari%,bri%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d res0 = _mm256_add_pd(ari0,bri0); +const __m256d res1 = _mm256_add_pd(ari1,bri1); +const __m256d res2 = _mm256_add_pd(ari2,bri2); +const __m256d res3 = _mm256_add_pd(ari3,bri3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + 
const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d sum0 = _mm256_add_pd(ari0,bri0); +const __m256d sum1 = _mm256_add_pd(ari1,bri1); +const __m256d sum2 = _mm256_add_pd(ari2,bri2); +const __m256d sum3 = _mm256_add_pd(ari3,bri3); +const __m256d rri0 = _mm256_loadu_pd(rr[0]); +const __m256d rri1 = _mm256_loadu_pd(rr[1]); +const __m256d rri2 = _mm256_loadu_pd(rr[2]); +const __m256d rri3 = _mm256_loadu_pd(rr[3]); +const __m256d res0 = _mm256_sub_pd(rri0,sum0); +const __m256d res1 = _mm256_sub_pd(rri1,sum1); +const __m256d res2 = _mm256_sub_pd(rri2,sum2); +const __m256d res3 = _mm256_sub_pd(rri3,sum3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(rr[0],ari0); +_mm256_storeu_pd(rr[1],ari1); +_mm256_storeu_pd(rr[2],ari2); +_mm256_storeu_pd(rr[3],ari3); + // END_INTERLEAVE + rr += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) { + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5); +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5); +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5); +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5); +__m256d p0 = _mm256_mul_pd(bir0,omii); +__m256d p1 = 
_mm256_mul_pd(bir1,omii); +__m256d p2 = _mm256_mul_pd(bir2,omii); +__m256d p3 = _mm256_mul_pd(bir3,omii); +p0 = _mm256_fmaddsub_pd(bri0,omrr,p0); +p1 = _mm256_fmaddsub_pd(bri1,omrr,p1); +p2 = _mm256_fmaddsub_pd(bri2,omrr,p2); +p3 = _mm256_fmaddsub_pd(bri3,omrr,p3); +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0)); +_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1)); +_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2)); +_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3)); +_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0)); +_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1)); +_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2)); +_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3)); + // END_INTERLEAVE + bb += 4; + aa += 4; + } while (aa < aend); +} +#endif + +EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) { + const uint32_t m = precomp->m; + D8MEM* aa = (D8MEM*)a; + D8MEM* bb = (D8MEM*)b; + D8MEM* const aend = aa + (m >> 2); + const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg)); + const __m512d omrr = _mm512_shuffle_pd(om, om, 0b00000000); + const __m512d omii = _mm512_shuffle_pd(om, om, 0b11111111); + do { +/* +BEGIN_TEMPLATE +const __m512d bri% = _mm512_loadu_pd(bb[%]); +const __m512d bir% = _mm512_shuffle_pd(bri%,bri%,0b10011001); +__m512d p% = _mm512_mul_pd(bir%,omii); +p% = _mm512_fmaddsub_pd(bri%,omrr,p%); +const __m512d ari% = _mm512_loadu_pd(aa[%]); +_mm512_storeu_pd(aa[%],_mm512_add_pd(ari%,p%)); +_mm512_storeu_pd(bb[%],_mm512_sub_pd(ari%,p%)); +bb += @; // ONCE +aa += @; // ONCE +END_TEMPLATE + */ +// BEGIN_INTERLEAVE 4 +const __m512d bri0 = _mm512_loadu_pd(bb[0]); +const __m512d bri1 = _mm512_loadu_pd(bb[1]); +const __m512d bri2 = _mm512_loadu_pd(bb[2]); +const __m512d bri3 = _mm512_loadu_pd(bb[3]); +const __m512d bir0 = _mm512_shuffle_pd(bri0,bri0,0b10011001); +const __m512d bir1 = _mm512_shuffle_pd(bri1,bri1,0b10011001); +const __m512d bir2 = _mm512_shuffle_pd(bri2,bri2,0b10011001); +const __m512d bir3 = _mm512_shuffle_pd(bri3,bri3,0b10011001); +__m512d p0 = _mm512_mul_pd(bir0,omii); +__m512d p1 = _mm512_mul_pd(bir1,omii); +__m512d p2 = _mm512_mul_pd(bir2,omii); +__m512d p3 = _mm512_mul_pd(bir3,omii); +p0 = _mm512_fmaddsub_pd(bri0,omrr,p0); +p1 = _mm512_fmaddsub_pd(bri1,omrr,p1); +p2 = _mm512_fmaddsub_pd(bri2,omrr,p2); +p3 = _mm512_fmaddsub_pd(bri3,omrr,p3); +const __m512d ari0 = _mm512_loadu_pd(aa[0]); +const __m512d ari1 = _mm512_loadu_pd(aa[1]); +const __m512d ari2 = _mm512_loadu_pd(aa[2]); +const __m512d ari3 = _mm512_loadu_pd(aa[3]); +_mm512_storeu_pd(aa[0],_mm512_add_pd(ari0,p0)); +_mm512_storeu_pd(aa[1],_mm512_add_pd(ari1,p1)); +_mm512_storeu_pd(aa[2],_mm512_add_pd(ari2,p2)); +_mm512_storeu_pd(aa[3],_mm512_add_pd(ari3,p3)); +_mm512_storeu_pd(bb[0], _mm512_sub_pd(ari0, p0)); +_mm512_storeu_pd(bb[1], _mm512_sub_pd(ari1, p1)); +_mm512_storeu_pd(bb[2], _mm512_sub_pd(ari2, p2)); +_mm512_storeu_pd(bb[3], _mm512_sub_pd(ari3, p3)); +bb += 4; // ONCE +aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea, + const void* omg) { + const uint32_t m = precomp->m; + const uint64_t OFFSET = slicea / sizeof(D8MEM); + D8MEM* aa = (D8MEM*)a; + const D8MEM* aend = aa + (m >> 2); + const __m512d om = 
_mm512_broadcast_f64x4(_mm256_loadu_pd(omg)); + const __m512d om1rr = _mm512_shuffle_pd(om, om, 0); + const __m512d om1ii = _mm512_shuffle_pd(om, om, 15); + const __m512d om2rr = _mm512_shuffle_pd(om, om, 0); + const __m512d om2ii = _mm512_shuffle_pd(om, om, 0); + const __m512d om3rr = _mm512_shuffle_pd(om, om, 15); + const __m512d om3ii = _mm512_shuffle_pd(om, om, 15); + do { +/* +BEGIN_TEMPLATE +__m512d ari% = _mm512_loadu_pd(aa[%]); +__m512d bri% = _mm512_loadu_pd((aa+OFFSET)[%]); +__m512d cri% = _mm512_loadu_pd((aa+2*OFFSET)[%]); +__m512d dri% = _mm512_loadu_pd((aa+3*OFFSET)[%]); +__m512d pa% = _mm512_shuffle_pd(cri%,cri%,5); +__m512d pb% = _mm512_shuffle_pd(dri%,dri%,5); +pa% = _mm512_mul_pd(pa%,om1ii); +pb% = _mm512_mul_pd(pb%,om1ii); +pa% = _mm512_fmaddsub_pd(cri%,om1rr,pa%); +pb% = _mm512_fmaddsub_pd(dri%,om1rr,pb%); +cri% = _mm512_sub_pd(ari%,pa%); +dri% = _mm512_sub_pd(bri%,pb%); +ari% = _mm512_add_pd(ari%,pa%); +bri% = _mm512_add_pd(bri%,pb%); +pa% = _mm512_shuffle_pd(bri%,bri%,5); +pb% = _mm512_shuffle_pd(dri%,dri%,5); +pa% = _mm512_mul_pd(pa%,om2ii); +pb% = _mm512_mul_pd(pb%,om3ii); +pa% = _mm512_fmaddsub_pd(bri%,om2rr,pa%); +pb% = _mm512_fmaddsub_pd(dri%,om3rr,pb%); +bri% = _mm512_sub_pd(ari%,pa%); +dri% = _mm512_sub_pd(cri%,pb%); +ari% = _mm512_add_pd(ari%,pa%); +cri% = _mm512_add_pd(cri%,pb%); +_mm512_storeu_pd(aa[%], ari%); +_mm512_storeu_pd((aa+OFFSET)[%],bri%); +_mm512_storeu_pd((aa+2*OFFSET)[%],cri%); +_mm512_storeu_pd((aa+3*OFFSET)[%],dri%); +aa += @; // ONCE +END_TEMPLATE + */ +// BEGIN_INTERLEAVE 2 +__m512d ari0 = _mm512_loadu_pd(aa[0]); +__m512d ari1 = _mm512_loadu_pd(aa[1]); +__m512d bri0 = _mm512_loadu_pd((aa+OFFSET)[0]); +__m512d bri1 = _mm512_loadu_pd((aa+OFFSET)[1]); +__m512d cri0 = _mm512_loadu_pd((aa+2*OFFSET)[0]); +__m512d cri1 = _mm512_loadu_pd((aa+2*OFFSET)[1]); +__m512d dri0 = _mm512_loadu_pd((aa+3*OFFSET)[0]); +__m512d dri1 = _mm512_loadu_pd((aa+3*OFFSET)[1]); +__m512d pa0 = _mm512_shuffle_pd(cri0,cri0,5); +__m512d pa1 = _mm512_shuffle_pd(cri1,cri1,5); +__m512d pb0 = _mm512_shuffle_pd(dri0,dri0,5); +__m512d pb1 = _mm512_shuffle_pd(dri1,dri1,5); +pa0 = _mm512_mul_pd(pa0,om1ii); +pa1 = _mm512_mul_pd(pa1,om1ii); +pb0 = _mm512_mul_pd(pb0,om1ii); +pb1 = _mm512_mul_pd(pb1,om1ii); +pa0 = _mm512_fmaddsub_pd(cri0,om1rr,pa0); +pa1 = _mm512_fmaddsub_pd(cri1,om1rr,pa1); +pb0 = _mm512_fmaddsub_pd(dri0,om1rr,pb0); +pb1 = _mm512_fmaddsub_pd(dri1,om1rr,pb1); +cri0 = _mm512_sub_pd(ari0,pa0); +cri1 = _mm512_sub_pd(ari1,pa1); +dri0 = _mm512_sub_pd(bri0,pb0); +dri1 = _mm512_sub_pd(bri1,pb1); +ari0 = _mm512_add_pd(ari0,pa0); +ari1 = _mm512_add_pd(ari1,pa1); +bri0 = _mm512_add_pd(bri0,pb0); +bri1 = _mm512_add_pd(bri1,pb1); +pa0 = _mm512_shuffle_pd(bri0,bri0,5); +pa1 = _mm512_shuffle_pd(bri1,bri1,5); +pb0 = _mm512_shuffle_pd(dri0,dri0,5); +pb1 = _mm512_shuffle_pd(dri1,dri1,5); +pa0 = _mm512_mul_pd(pa0,om2ii); +pa1 = _mm512_mul_pd(pa1,om2ii); +pb0 = _mm512_mul_pd(pb0,om3ii); +pb1 = _mm512_mul_pd(pb1,om3ii); +pa0 = _mm512_fmaddsub_pd(bri0,om2rr,pa0); +pa1 = _mm512_fmaddsub_pd(bri1,om2rr,pa1); +pb0 = _mm512_fmaddsub_pd(dri0,om3rr,pb0); +pb1 = _mm512_fmaddsub_pd(dri1,om3rr,pb1); +bri0 = _mm512_sub_pd(ari0,pa0); +bri1 = _mm512_sub_pd(ari1,pa1); +dri0 = _mm512_sub_pd(cri0,pb0); +dri1 = _mm512_sub_pd(cri1,pb1); +ari0 = _mm512_add_pd(ari0,pa0); +ari1 = _mm512_add_pd(ari1,pa1); +cri0 = _mm512_add_pd(cri0,pb0); +cri1 = _mm512_add_pd(cri1,pb1); +_mm512_storeu_pd(aa[0], ari0); +_mm512_storeu_pd(aa[1], ari1); +_mm512_storeu_pd((aa+OFFSET)[0],bri0); +_mm512_storeu_pd((aa+OFFSET)[1],bri1); 
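+// the four stores below write back the c and d rows of the four-row butterfly
+// (at offsets 2*OFFSET and 3*OFFSET); the a and b rows were stored just above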
+_mm512_storeu_pd((aa+2*OFFSET)[0],cri0);
+_mm512_storeu_pd((aa+2*OFFSET)[1],cri1);
+_mm512_storeu_pd((aa+3*OFFSET)[0],dri0);
+_mm512_storeu_pd((aa+3*OFFSET)[1],dri1);
+aa += 2; // ONCE
+ // END_INTERLEAVE
+ } while (aa < aend);
+ _mm256_zeroupper();
+}
+
+
diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_internal.h b/spqlios/lib/spqlios/cplx/cplx_fft_internal.h
new file mode 100644
index 0000000..7aa17fd
--- /dev/null
+++ b/spqlios/lib/spqlios/cplx/cplx_fft_internal.h
@@ -0,0 +1,123 @@
+#ifndef SPQLIOS_CPLX_FFT_INTERNAL_H
+#define SPQLIOS_CPLX_FFT_INTERNAL_H
+
+#include "cplx_fft.h"
+
+/** @brief a complex number contains two doubles real,imag */
+typedef double CPLX[2];
+
+EXPORT void cplx_set(CPLX r, const CPLX a);
+EXPORT void cplx_neg(CPLX r, const CPLX a);
+EXPORT void cplx_add(CPLX r, const CPLX a, const CPLX b);
+EXPORT void cplx_sub(CPLX r, const CPLX a, const CPLX b);
+EXPORT void cplx_mul(CPLX r, const CPLX a, const CPLX b);
+
+/**
+ * @brief splits 2h evaluations of one polynomial into 2 times h evaluations of the even/odd polynomials
+ * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
+ * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
+ * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
+ * @param h number of "coefficients" h >= 1
+ * @param data 2h complex coefficients interleaved and 256b aligned
+ * @param powom y represented as (yre,yim)
+ */
+EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom);
+EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]);
+
+/**
+ * Input: Q(y),Q(-y)
+ * Output: P_0(z),P_1(z)
+ * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
+ * @param data 2 complex coefficients interleaved and 256b aligned
+ * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
+ */
+EXPORT void split_fft_last_ref(CPLX* data, const CPLX powom);
+
+EXPORT void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
+EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega);
+EXPORT void cplx_ifft16_ref(void* data, const void* omega);
+
+/**
+ * @brief compute the ifft evaluations of P in place
+ * ifft(data) = ifft_rec(data, i);
+ * function ifft_rec(data, omega) {
+ * if #data = 1: return data
+ * let s = sqrt(omega) w. re(s)>0
+ * let (u,v) = data
+ * return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s)
+ * }
+ * @param itables precomputed tables (contains all the powers of omega in the order they are used)
+ * @param data vector of m complexes (coeffs as input, evals as output)
+ */
+EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* itables, void* data);
+EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data);
+EXPORT void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
+EXPORT void cplx_fft16_avx_fma(void* data, const void* omega);
+EXPORT void cplx_fft16_ref(void* data, const void* omega);
+
+/**
+ * @brief compute the fft evaluations of P in place
+ * fft(data) = fft_rec(data, i);
+ * function fft_rec(data, omega) {
+ * if #data = 1: return data
+ * let s = sqrt(omega) w. re(s)>0
+ * let (u,v) = merge_fft(data, s)
+ * return [fft_rec(u, s), fft_rec(v, -s)]
+ * }
+ * @param tables precomputed tables (contains all the powers of omega in the order they are used)
+ * @param data vector of m complexes (coeffs as input, evals as output)
+ */
+EXPORT void cplx_fft_ref(const CPLX_FFT_PRECOMP* tables, void* data);
+EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data);
+
+/**
+ * @brief merges 2 times h evaluations of the even/odd polynomials into 2h evaluations of a single polynomial
+ * Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
+ * Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
+ * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
+ * @param h number of "coefficients" h >= 1
+ * @param data 2h complex coefficients interleaved and 256b aligned
+ * @param powom y represented as (yre,yim)
+ */
+EXPORT void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom);
+
+EXPORT void citwiddle(CPLX a, CPLX b, const CPLX om);
+EXPORT void ctwiddle(CPLX a, CPLX b, const CPLX om);
+EXPORT void invctwiddle(CPLX a, CPLX b, const CPLX ombar);
+EXPORT void invcitwiddle(CPLX a, CPLX b, const CPLX ombar);
+
+// CONVERSIONS
+
+/** @brief r = x from ZnX (coeffs as signed int32_t's) to double */
+EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
+EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
+/** @brief r = x from double to ZnX (coeffs as signed int32_t's) */
+EXPORT void cplx_to_znx32_ref(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
+EXPORT void cplx_to_znx32_avx2_fma(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
+/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) to double */
+EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
+EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
+/** @brief x = c mod 1 to TnX (coeffs as signed int32_t's) */
+EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
+EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
+/** @brief r = x from RnX (coeffs as doubles) to double */
+EXPORT void cplx_from_rnx64_ref(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
+EXPORT void cplx_from_rnx64_avx2_fma(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
+/** @brief r = x from double to RnX (coeffs as doubles) */
+EXPORT void cplx_to_rnx64_ref(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
+EXPORT void cplx_to_rnx64_avx2_fma(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
+/** @brief r = x rounded to integers in RnX (coeffs as doubles) */
+EXPORT void cplx_round_to_rnx64_ref(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
+EXPORT void cplx_round_to_rnx64_avx2_fma(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
+
+// fftvec operations
+/** @brief element-wise addmul r += ab */
+EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
+EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
+EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
+EXPORT void
cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b); +/** @brief element-wise mul r = ab */ +EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); + +#endif // SPQLIOS_CPLX_FFT_INTERNAL_H diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_private.h b/spqlios/lib/spqlios/cplx/cplx_fft_private.h new file mode 100644 index 0000000..be1ae3e --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_private.h @@ -0,0 +1,109 @@ +#ifndef SPQLIOS_CPLX_FFT_PRIVATE_H +#define SPQLIOS_CPLX_FFT_PRIVATE_H + +#include "cplx_fft.h" + +typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP; +typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP; + +typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*); +typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*); +// conversions +typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*); +typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*); +typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +// fftvec operations +typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*); + +typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); + +struct cplx_ifft_precomp { + IFFT_FUNCTION function; + int64_t m; + uint64_t buf_size; + double* powomegas; + void* aligned_buffers; +}; + +struct cplx_fft_precomp { + FFT_FUNCTION function; + int64_t m; + uint64_t buf_size; + double* powomegas; + void* aligned_buffers; +}; + +struct cplx_from_znx32_precomp { + FROM_ZNX32_FUNCTION function; + int64_t m; +}; + +struct cplx_to_znx32_precomp { + TO_ZNX32_FUNCTION function; + int64_t m; + double divisor; +}; + +struct cplx_from_tnx32_precomp { + FROM_TNX32_FUNCTION function; + int64_t m; +}; + +struct cplx_to_tnx32_precomp { + TO_TNX32_FUNCTION function; + int64_t m; + double divisor; +}; + +struct cplx_from_rnx64_precomp { + FROM_RNX64_FUNCTION function; + int64_t m; +}; + +struct cplx_to_rnx64_precomp { + TO_RNX64_FUNCTION function; + int64_t m; + double divisor; +}; + +struct cplx_round_to_rnx64_precomp { + ROUND_TO_RNX64_FUNCTION function; + int64_t m; + double divisor; + uint32_t log2bound; +}; + +typedef struct cplx_mul_precomp { + FFTVEC_MUL_FUNCTION function; + int64_t m; +} CPLX_FFTVEC_MUL_PRECOMP; + +typedef struct cplx_addmul_precomp { + FFTVEC_ADDMUL_FUNCTION function; + int64_t m; +} CPLX_FFTVEC_ADDMUL_PRECOMP; + +struct cplx_twiddle_precomp { + FFTVEC_TWIDDLE_FUNCTION function; + int64_t m; + }; + +struct cplx_bitwiddle_precomp { + FFTVEC_BITWIDDLE_FUNCTION function; + int64_t m; +}; + +EXPORT void 
cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); +EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); +EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om); +EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om); + +#endif // SPQLIOS_CPLX_FFT_PRIVATE_H diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_ref.c b/spqlios/lib/spqlios/cplx/cplx_fft_ref.c new file mode 100644 index 0000000..dea721d --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_ref.c @@ -0,0 +1,367 @@ +#include +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +/** @brief (a,b) <- (a+omega.b,a-omega.b) */ +void ctwiddle(CPLX a, CPLX b, const CPLX om) { + double re = om[0] * b[0] - om[1] * b[1]; + double im = om[0] * b[1] + om[1] * b[0]; + b[0] = a[0] - re; + b[1] = a[1] - im; + a[0] += re; + a[1] += im; +} + +/** @brief (a,b) <- (a+i.omega.b,a-i.omega.b) */ +void citwiddle(CPLX a, CPLX b, const CPLX om) { + double re = -om[1] * b[0] - om[0] * b[1]; + double im = -om[1] * b[1] + om[0] * b[0]; + b[0] = a[0] - re; + b[1] = a[1] - im; + a[0] += re; + a[1] += im; +} + +/** + * @brief FFT modulo X^16-omega^2 (in registers) + * @param data contains 16 complexes + * @param omega 8 complexes in this order: + * omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma + * alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) + * j = sqrt(i), k=sqrt(j) + */ +void cplx_fft16_ref(void* data, const void* omega) { + CPLX* d = data; + const CPLX* om = omega; + // first pass + for (uint64_t i = 0; i < 8; ++i) { + ctwiddle(d[0 + i], d[8 + i], om[0]); + } + // + ctwiddle(d[0], d[4], om[1]); + ctwiddle(d[1], d[5], om[1]); + ctwiddle(d[2], d[6], om[1]); + ctwiddle(d[3], d[7], om[1]); + citwiddle(d[8], d[12], om[1]); + citwiddle(d[9], d[13], om[1]); + citwiddle(d[10], d[14], om[1]); + citwiddle(d[11], d[15], om[1]); + // + ctwiddle(d[0], d[2], om[2]); + ctwiddle(d[1], d[3], om[2]); + citwiddle(d[4], d[6], om[2]); + citwiddle(d[5], d[7], om[2]); + ctwiddle(d[8], d[10], om[3]); + ctwiddle(d[9], d[11], om[3]); + citwiddle(d[12], d[14], om[3]); + citwiddle(d[13], d[15], om[3]); + // + ctwiddle(d[0], d[1], om[4]); + citwiddle(d[2], d[3], om[4]); + ctwiddle(d[4], d[5], om[5]); + citwiddle(d[6], d[7], om[5]); + ctwiddle(d[8], d[9], om[6]); + citwiddle(d[10], d[11], om[6]); + ctwiddle(d[12], d[13], om[7]); + citwiddle(d[14], d[15], om[7]); +} + +double cos_2pix(double x) { return m_accurate_cos(2 * M_PI * x); } +double sin_2pix(double x) { return m_accurate_sin(2 * M_PI * x); } +void cplx_set_e2pix(CPLX res, double x) { + res[0] = cos_2pix(x); + res[1] = sin_2pix(x); +} + +void cplx_fft16_precomp(const double entry_pwr, CPLX** omg) { + static const double j_pow = 1. / 8.; + static const double k_pow = 1. 
/ 16.; + const double pom = entry_pwr / 2.; + const double pom_2 = entry_pwr / 4.; + const double pom_4 = entry_pwr / 8.; + const double pom_8 = entry_pwr / 16.; + cplx_set_e2pix((*omg)[0], pom); + cplx_set_e2pix((*omg)[1], pom_2); + cplx_set_e2pix((*omg)[2], pom_4); + cplx_set_e2pix((*omg)[3], pom_4 + j_pow); + cplx_set_e2pix((*omg)[4], pom_8); + cplx_set_e2pix((*omg)[5], pom_8 + j_pow); + cplx_set_e2pix((*omg)[6], pom_8 + k_pow); + cplx_set_e2pix((*omg)[7], pom_8 + j_pow + k_pow); + *omg += 8; +} + +/** + * @brief h twiddles-fft on the same omega + * (also called merge-fft)merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial + * Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom) { + CPLX* d0 = data; + CPLX* d1 = data + h; + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d1[i], powom); + } +} + +void cplx_bitwiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) { + CPLX* d0 = data; + CPLX* d1 = data + h; + CPLX* d2 = data + 2*h; + CPLX* d3 = data + 3*h; + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d2[i], powom[0]); + ctwiddle(d1[i], d3[i], powom[0]); + } + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d1[i], powom[1]); + citwiddle(d2[i], d3[i], powom[1]); + } +} + +/** + * Input: P_0(z),P_1(z) + * Output: Q(y),Q(-y) + * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z + * @param data 2 complexes coefficients interleaved and 256b aligned + * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) + */ +void merge_fft_last_ref(CPLX* data, const CPLX powom) { + CPLX prod; + cplx_mul(prod, data[1], powom); + cplx_sub(data[1], data[0], prod); + cplx_add(data[0], data[0], prod); +} + +void cplx_fft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) { + CPLX* data = (CPLX*)dat; + CPLX* const dend = data + m; + for (int32_t h = m / 2; h >= 2; h >>= 1) { + for (CPLX* d = data; d < dend; d += 2 * h) { + if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort(); + cplx_twiddle_fft_ref(h, d, **omg); + *omg += 2; + } +#if 0 + printf("after merge %d: ", h); + for (uint64_t ii=0; ii>= 1; + } + while(mm>16) { + uint32_t h = mm/4; + for (CPLX* d = data; d < dend; d += mm) { + cplx_bitwiddle_fft_ref(h, d, *omg); + *omg += 2; + } + mm=h; + } + for (CPLX* d = data; d < dend; d += 16) { + cplx_fft16_ref(d, *omg); + *omg += 8; + } +#if 0 + printf("after last: "); + for (uint64_t ii=0; ii16) { + double pom = ss / 4.; + uint32_t h = mm / 4; + for (uint32_t i = 0; i < m / mm; i++) { + double om = pom + fracrevbits(i) / 4.; + cplx_set_e2pix(omg[0][0], 2. 
* om); + cplx_set_e2pix(omg[0][1], om); + *omg += 2; + } + mm = h; + ss = pom; + } + { + // mm=16 + for (uint32_t i = 0; i < m / 16; i++) { + cplx_fft16_precomp(ss + fracrevbits(i), omg); + } + } +} + +/** @brief fills omega for cplx_fft_bfs_2 modulo X^m-exp(i.2.pi.entry_pwr) */ +void fill_cplx_fft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) { + double pom = entry_pwr / 2.; + for (int32_t h = m / 2; h >= 2; h >>= 1) { + for (uint32_t i = 0; i < m / (2 * h); i++) { + cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.); + cplx_set(omg[0][1], omg[0][0]); + *omg += 2; + } + pom /= 2; + } + { + // h=1 + for (uint32_t i = 0; i < m / 2; i++) { + cplx_set_e2pix((*omg)[0], pom + fracrevbits(i) / 2.); + cplx_neg((*omg)[1], (*omg)[0]); + *omg += 2; + } + } +} + +/** @brief fills omega for cplx_fft_rec modulo X^m-exp(i.2.pi.entry_pwr) */ +void fill_cplx_fft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) { + // note that the cases below are for recursive calls only! + // externally, this function shall only be called with m>=4096 + if (m == 1) return; + if (m <= 8) return fill_cplx_fft_omegas_bfs_2(entry_pwr, omg, m); + if (m <= 2048) return fill_cplx_fft_omegas_bfs_16(entry_pwr, omg, m); + double pom = entry_pwr / 2.; + cplx_set_e2pix((*omg)[0], pom); + cplx_set_e2pix((*omg)[1], pom); + *omg += 2; + fill_cplx_fft_omegas_rec_16(pom, omg, m / 2); + fill_cplx_fft_omegas_rec_16(pom + 0.5, omg, m / 2); +} + +void cplx_fft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) { + if (m == 1) return; + if (m <= 8) return cplx_fft_ref_bfs_2(dat, omg, m); + if (m <= 2048) return cplx_fft_ref_bfs_16(dat, omg, m); + const uint32_t h = m / 2; + if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort(); + cplx_twiddle_fft_ref(h, dat, **omg); + *omg += 2; + cplx_fft_ref_rec_16(dat, omg, h); + cplx_fft_ref_rec_16(dat + h, omg, h); +} + +void cplx_fft_ref(const CPLX_FFT_PRECOMP* precomp, void* d) { + CPLX* data = (CPLX*)d; + const int32_t m = precomp->m; + const CPLX* omg = (CPLX*)precomp->powomegas; + if (m == 1) return; + if (m <= 8) return cplx_fft_ref_bfs_2(data, &omg, m); + if (m <= 2048) return cplx_fft_ref_bfs_16(data, &omg, m); + cplx_fft_ref_rec_16(data, &omg, m); +} + +EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b((2 * m)* sizeof(CPLX)); + const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX)); + void* reps = malloc(sizeof(CPLX_FFT_PRECOMP) + 63 // padding + + OMG_SPACE // tables //TODO 16? 
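+ // layout of the single allocation: the precomp struct, then the 64b-aligned
+ // twiddle tables (OMG_SPACE bytes), then num_buffers buffers of BUF_SIZE bytes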
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(CPLX_FFT_PRECOMP)); + CPLX_FFT_PRECOMP* r = (CPLX_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + CPLX* omg = (CPLX*)r->powomegas; + if (m <= 8) { + fill_cplx_fft_omegas_bfs_2(0.25, &omg, m); + } else if (m <= 2048) { + fill_cplx_fft_omegas_bfs_16(0.25, &omg, m); + } else { + fill_cplx_fft_omegas_rec_16(0.25, &omg, m); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (m <= 4) { + // currently, we do not have any acceletated + // implementation for m<=4 + r->function = cplx_fft_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fft_avx2_fma; + } else { + r->function = cplx_fft_ref; + } + } + return reps; +} + +EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index) { + return (uint8_t *)tables->aligned_buffers + buffer_index * tables->buf_size; +} + +EXPORT void cplx_fft_simple(uint32_t m, void* data) { + static CPLX_FFT_PRECOMP* p[31] = {0}; + CPLX_FFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_cplx_fft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); } diff --git a/spqlios/lib/spqlios/cplx/cplx_fft_sse.c b/spqlios/lib/spqlios/cplx/cplx_fft_sse.c new file mode 100644 index 0000000..07b8fb5 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fft_sse.c @@ -0,0 +1,310 @@ +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D2MEM[2]; + +EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const D2MEM* aa = (D2MEM*)a; + const D2MEM* bb = (D2MEM*)b; + D2MEM* rr = (D2MEM*)r; + const D2MEM* const aend = aa + m; + do { + /* +BEGIN_TEMPLATE +const __m128d ari% = _mm_loadu_pd(aa[%]); +const __m128d bri% = _mm_loadu_pd(bb[%]); +const __m128d rri% = _mm_loadu_pd(rr[%]); +const __m128d bir% = _mm_shuffle_pd(bri%,bri%, 5); +const __m128d aii% = _mm_shuffle_pd(ari%,ari%, 15); +const __m128d pro% = _mm_fmaddsub_pd(aii%,bir%,rri%); +const __m128d arr% = _mm_shuffle_pd(ari%,ari%, 0); +const __m128d res% = _mm_fmaddsub_pd(arr%,bri%,pro%); +_mm_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + const __m128d ari0 = _mm_loadu_pd(aa[0]); + const __m128d ari1 = _mm_loadu_pd(aa[1]); + const __m128d bri0 = _mm_loadu_pd(bb[0]); + const __m128d bri1 = _mm_loadu_pd(bb[1]); + const __m128d rri0 = _mm_loadu_pd(rr[0]); + const __m128d rri1 = _mm_loadu_pd(rr[1]); + const __m128d bir0 = _mm_shuffle_pd(bri0, bri0, 0b01); + const __m128d bir1 = _mm_shuffle_pd(bri1, bri1, 0b01); + const __m128d aii0 = _mm_shuffle_pd(ari0, ari0, 0b11); + const __m128d aii1 = _mm_shuffle_pd(ari1, ari1, 0b11); + const __m128d pro0 = _mm_fmaddsub_pd(aii0, bir0, rri0); + const __m128d pro1 = _mm_fmaddsub_pd(aii1, bir1, rri1); + const __m128d arr0 = _mm_shuffle_pd(ari0, ari0, 0b00); + const __m128d arr1 = _mm_shuffle_pd(ari1, ari1, 0b00); + const __m128d res0 = _mm_fmaddsub_pd(arr0, bri0, pro0); + const __m128d res1 = _mm_fmaddsub_pd(arr1, bri1, pro1); + _mm_storeu_pd(rr[0], res0); + _mm_storeu_pd(rr[1], res1); + rr += 2; // ONCE + aa += 2; // ONCE + bb += 2; // ONCE + // END_INTERLEAVE + } while 
(aa < aend); +} + +#if 0 +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b +const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a +const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a +const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a +const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a +const __m256d pro0 = _mm256_mul_pd(aii0,bir0); +const __m256d pro1 = _mm256_mul_pd(aii1,bir1); +const __m256d pro2 = _mm256_mul_pd(aii2,bir2); +const __m256d pro3 = _mm256_mul_pd(aii3,bir3); +const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a +const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a +const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a +const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a +const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0); +const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1); +const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2); +const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d res% = _mm256_add_pd(ari%,bri%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d res0 = _mm256_add_pd(ari0,bri0); 
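+// each __m256d holds two packed complexes, so this unrolled block adds
+// eight complexes (four registers) per loop iteration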
+const __m256d res1 = _mm256_add_pd(ari1,bri1); +const __m256d res2 = _mm256_add_pd(ari2,bri2); +const __m256d res3 = _mm256_add_pd(ari3,bri3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d sum0 = _mm256_add_pd(ari0,bri0); +const __m256d sum1 = _mm256_add_pd(ari1,bri1); +const __m256d sum2 = _mm256_add_pd(ari2,bri2); +const __m256d sum3 = _mm256_add_pd(ari3,bri3); +const __m256d rri0 = _mm256_loadu_pd(rr[0]); +const __m256d rri1 = _mm256_loadu_pd(rr[1]); +const __m256d rri2 = _mm256_loadu_pd(rr[2]); +const __m256d rri3 = _mm256_loadu_pd(rr[3]); +const __m256d res0 = _mm256_sub_pd(rri0,sum0); +const __m256d res1 = _mm256_sub_pd(rri1,sum1); +const __m256d res2 = _mm256_sub_pd(rri2,sum2); +const __m256d res3 = _mm256_sub_pd(rri3,sum3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(rr[0],ari0); +_mm256_storeu_pd(rr[1],ari1); +_mm256_storeu_pd(rr[2],ari2); +_mm256_storeu_pd(rr[3],ari3); + // END_INTERLEAVE + rr += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) { + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +END_TEMPLATE + */ + // 
BEGIN_INTERLEAVE 4 +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5); +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5); +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5); +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5); +__m256d p0 = _mm256_mul_pd(bir0,omii); +__m256d p1 = _mm256_mul_pd(bir1,omii); +__m256d p2 = _mm256_mul_pd(bir2,omii); +__m256d p3 = _mm256_mul_pd(bir3,omii); +p0 = _mm256_fmaddsub_pd(bri0,omrr,p0); +p1 = _mm256_fmaddsub_pd(bri1,omrr,p1); +p2 = _mm256_fmaddsub_pd(bri2,omrr,p2); +p3 = _mm256_fmaddsub_pd(bri3,omrr,p3); +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0)); +_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1)); +_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2)); +_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3)); +_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0)); +_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1)); +_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2)); +_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3)); + // END_INTERLEAVE + bb += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_innerprod_avx2_fma(const CPLX_FFTVEC_INNERPROD_PRECOMP* precomp, const int32_t ellbar, + const uint64_t lda, const uint64_t ldb, + void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const uint32_t blk = precomp->blk; + const uint32_t nblocks = precomp->nblocks; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + const uint64_t ldda = lda >> 4; // in CPLX + const uint64_t lddb = ldb >> 4; + if (m==0) { + memset(r, 0, m*sizeof(CPLX)); + return; + } + for (uint32_t k=0; kmul_func, rrr, aaa, bbb); + for (int32_t i=1; iaddmul_func, rrr, aaa + i * ldda, bbb + i * lddb); + } + } +} +#endif + diff --git a/spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c b/spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c new file mode 100644 index 0000000..788c4d5 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c @@ -0,0 +1,389 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; + +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); // conj of b + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); // conj of b + const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); // conj of b + const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); // conj of b + const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); // im of a + const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); // im of a + const __m256d aii2 = _mm256_shuffle_pd(ari2, ari2, 15); // im of a + const __m256d aii3 = _mm256_shuffle_pd(ari3, ari3, 15); // im of a + const __m256d pro0 = _mm256_mul_pd(aii0, bir0); + const __m256d pro1 = _mm256_mul_pd(aii1, bir1); + const __m256d pro2 = _mm256_mul_pd(aii2, bir2); + const __m256d pro3 = _mm256_mul_pd(aii3, bir3); + const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); // rr of a + const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); // rr of a + const __m256d arr2 = _mm256_shuffle_pd(ari2, ari2, 0); // rr of a + const __m256d arr3 = _mm256_shuffle_pd(ari3, ari3, 0); // rr of a + const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0); + const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1); + const __m256d res2 = _mm256_fmaddsub_pd(arr2, bri2, pro2); + const __m256d res3 = _mm256_fmaddsub_pd(arr3, bri3, pro3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + rr += 4; // ONCE + aa += 4; // ONCE + bb += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + + + +EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); +const __m256d pro% = _mm256_fmaddsub_pd(aii%,bir%,rri%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d rri0 = _mm256_loadu_pd(rr[0]); + const __m256d rri1 = _mm256_loadu_pd(rr[1]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); + const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); + const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); + const __m256d pro0 = _mm256_fmaddsub_pd(aii0, bir0, rri0); + const __m256d pro1 = _mm256_fmaddsub_pd(aii1, bir1, rri1); + const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); + const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); + const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0); + const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + rr += 2; // ONCE + aa += 2; // ONCE + bb += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea, + const void* omg) { + const uint32_t m = precomp->m; + const uint64_t OFFSET = slicea / sizeof(D4MEM); + D4MEM* aa = (D4MEM*)a; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d om1rr = _mm256_shuffle_pd(om, om, 0); + const __m256d om1ii = _mm256_shuffle_pd(om, om, 15); + const __m256d om2rr = _mm256_shuffle_pd(om, om, 0); + const __m256d om2ii = _mm256_shuffle_pd(om, om, 0); + const __m256d om3rr = _mm256_shuffle_pd(om, om, 15); + const __m256d om3ii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +__m256d ari% = _mm256_loadu_pd(aa[%]); +__m256d bri% = _mm256_loadu_pd((aa+OFFSET)[%]); +__m256d cri% = _mm256_loadu_pd((aa+2*OFFSET)[%]); +__m256d dri% = _mm256_loadu_pd((aa+3*OFFSET)[%]); +__m256d pa% = _mm256_shuffle_pd(cri%,cri%,5); +__m256d pb% = _mm256_shuffle_pd(dri%,dri%,5); +pa% = _mm256_mul_pd(pa%,om1ii); +pb% = _mm256_mul_pd(pb%,om1ii); +pa% = _mm256_fmaddsub_pd(cri%,om1rr,pa%); +pb% = _mm256_fmaddsub_pd(dri%,om1rr,pb%); +cri% = _mm256_sub_pd(ari%,pa%); +dri% = _mm256_sub_pd(bri%,pb%); +ari% = _mm256_add_pd(ari%,pa%); +bri% = _mm256_add_pd(bri%,pb%); +pa% = _mm256_shuffle_pd(bri%,bri%,5); +pb% = _mm256_shuffle_pd(dri%,dri%,5); +pa% = _mm256_mul_pd(pa%,om2ii); +pb% = _mm256_mul_pd(pb%,om3ii); +pa% = _mm256_fmaddsub_pd(bri%,om2rr,pa%); +pb% = _mm256_fmaddsub_pd(dri%,om3rr,pb%); +bri% = _mm256_sub_pd(ari%,pa%); +dri% = _mm256_sub_pd(cri%,pb%); +ari% = _mm256_add_pd(ari%,pa%); +cri% = _mm256_add_pd(cri%,pb%); +_mm256_storeu_pd(aa[%], ari%); +_mm256_storeu_pd((aa+OFFSET)[%],bri%); +_mm256_storeu_pd((aa+2*OFFSET)[%],cri%); +_mm256_storeu_pd((aa+3*OFFSET)[%],dri%); +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 1 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + __m256d ari0 = _mm256_loadu_pd(aa[0]); + __m256d bri0 = _mm256_loadu_pd((aa + OFFSET)[0]); + __m256d cri0 = _mm256_loadu_pd((aa + 2 * OFFSET)[0]); + __m256d dri0 = _mm256_loadu_pd((aa + 3 * OFFSET)[0]); + __m256d pa0 = _mm256_shuffle_pd(cri0, cri0, 5); + __m256d pb0 = _mm256_shuffle_pd(dri0, dri0, 5); + pa0 = _mm256_mul_pd(pa0, om1ii); + pb0 = _mm256_mul_pd(pb0, om1ii); + pa0 = _mm256_fmaddsub_pd(cri0, om1rr, pa0); + pb0 = _mm256_fmaddsub_pd(dri0, om1rr, pb0); + cri0 = _mm256_sub_pd(ari0, pa0); + dri0 = _mm256_sub_pd(bri0, pb0); + ari0 = _mm256_add_pd(ari0, pa0); + bri0 = _mm256_add_pd(bri0, pb0); + pa0 = _mm256_shuffle_pd(bri0, bri0, 5); + pb0 = _mm256_shuffle_pd(dri0, dri0, 5); + pa0 = _mm256_mul_pd(pa0, om2ii); + pb0 = _mm256_mul_pd(pb0, om3ii); + pa0 = _mm256_fmaddsub_pd(bri0, om2rr, pa0); + pb0 = _mm256_fmaddsub_pd(dri0, om3rr, pb0); + bri0 = _mm256_sub_pd(ari0, pa0); + dri0 = _mm256_sub_pd(cri0, pb0); + ari0 = _mm256_add_pd(ari0, pa0); + cri0 = _mm256_add_pd(cri0, pb0); + _mm256_storeu_pd(aa[0], ari0); + _mm256_storeu_pd((aa + OFFSET)[0], bri0); + _mm256_storeu_pd((aa + 2 * OFFSET)[0], cri0); + _mm256_storeu_pd((aa + 3 * OFFSET)[0], dri0); + aa += 1; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d sum0 = _mm256_add_pd(ari0, bri0); + const __m256d sum1 = _mm256_add_pd(ari1, bri1); + const __m256d sum2 = _mm256_add_pd(ari2, bri2); + const __m256d sum3 = _mm256_add_pd(ari3, bri3); + const __m256d rri0 = _mm256_loadu_pd(rr[0]); + const __m256d rri1 = _mm256_loadu_pd(rr[1]); + const __m256d rri2 = _mm256_loadu_pd(rr[2]); + const __m256d rri3 = _mm256_loadu_pd(rr[3]); + const __m256d res0 = _mm256_sub_pd(rri0, sum0); + const __m256d res1 = _mm256_sub_pd(rri1, sum1); + const __m256d res2 = _mm256_sub_pd(rri2, sum2); + const __m256d res3 = _mm256_sub_pd(rri3, sum3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* + BEGIN_TEMPLATE + const __m256d ari% = _mm256_loadu_pd(aa[%]); + const __m256d bri% = _mm256_loadu_pd(bb[%]); + const __m256d res% = _mm256_add_pd(ari%,bri%); + _mm256_storeu_pd(rr[%],res%); + rr += @; // ONCE + aa += @; // ONCE + bb += @; // ONCE + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d res0 = _mm256_add_pd(ari0, bri0); + const __m256d res1 = _mm256_add_pd(ari1, bri1); + const __m256d res2 = _mm256_add_pd(ari2, bri2); + const __m256d res3 = _mm256_add_pd(ari3, bri3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + rr += 4; // ONCE + aa += 4; // ONCE + bb += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) { + const uint32_t m = precomp->m; + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +bb += @; // ONCE +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); + const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); + const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); + __m256d p0 = _mm256_mul_pd(bir0, omii); + __m256d p1 = _mm256_mul_pd(bir1, omii); + __m256d p2 = _mm256_mul_pd(bir2, omii); + __m256d p3 = _mm256_mul_pd(bir3, omii); + p0 = _mm256_fmaddsub_pd(bri0, omrr, p0); + p1 = _mm256_fmaddsub_pd(bri1, omrr, p1); + p2 = _mm256_fmaddsub_pd(bri2, omrr, p2); + p3 = _mm256_fmaddsub_pd(bri3, omrr, p3); + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + _mm256_storeu_pd(aa[0], _mm256_add_pd(ari0, p0)); + _mm256_storeu_pd(aa[1], _mm256_add_pd(ari1, p1)); + _mm256_storeu_pd(aa[2], _mm256_add_pd(ari2, p2)); + _mm256_storeu_pd(aa[3], _mm256_add_pd(ari3, p3)); + _mm256_storeu_pd(bb[0], _mm256_sub_pd(ari0, p0)); + _mm256_storeu_pd(bb[1], _mm256_sub_pd(ari1, p1)); + _mm256_storeu_pd(bb[2], _mm256_sub_pd(ari2, p2)); + _mm256_storeu_pd(bb[3], _mm256_sub_pd(ari3, p3)); + bb += 4; // ONCE + aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +rr += @; // ONCE +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + _mm256_storeu_pd(rr[0], ari0); + _mm256_storeu_pd(rr[1], ari1); + _mm256_storeu_pd(rr[2], ari2); + _mm256_storeu_pd(rr[3], ari3); + rr += 4; // ONCE + aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c b/spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c new file mode 100644 index 0000000..f7d4629 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c @@ -0,0 +1,85 @@ +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + for (uint32_t i = 0; i < m; ++i) { + const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1]; + const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0]; + rr[i][0] += re; + rr[i][1] += im; + } +} + +EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + for (uint32_t i = 0; i < m; ++i) { + const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1]; + const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0]; + rr[i][0] = re; + rr[i][1] = im; + } +} + +EXPORT void* init_cplx_fftvec_addmul_precomp(CPLX_FFTVEC_ADDMUL_PRECOMP* r, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of two"); + r->m = m; + if (m <= 4) { + r->function = cplx_fftvec_addmul_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fftvec_addmul_fma; + } else { + r->function = cplx_fftvec_addmul_ref; + } + return r; +} + +EXPORT void* init_cplx_fftvec_mul_precomp(CPLX_FFTVEC_MUL_PRECOMP* r, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of two"); + r->m = m; + if (m <= 4) { + r->function = cplx_fftvec_mul_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fftvec_mul_fma; + } else { + r->function = cplx_fftvec_mul_ref; + } + return r; +} + +EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) { + CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP)); + return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m)); +} + +EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m) { + CPLX_FFTVEC_MUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP)); + return spqlios_keep_or_free(r, init_cplx_fftvec_mul_precomp(r, m)); +} + +EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) { + static CPLX_FFTVEC_MUL_PRECOMP p[31] = {0}; + CPLX_FFTVEC_MUL_PRECOMP* f = p + log2m(m); + if (!f->function) { + if (!init_cplx_fftvec_mul_precomp(f, m)) abort(); + } + f->function(f, r, a, b); +} + +EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) { + static CPLX_FFTVEC_ADDMUL_PRECOMP p[31] = {0}; + CPLX_FFTVEC_ADDMUL_PRECOMP* f = p + log2m(m); + if (!f->function) { + if (!init_cplx_fftvec_addmul_precomp(f, m)) abort(); + } + f->function(f, r, a, b); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s b/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s new file mode 100644 
index 0000000..bc9ea10 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s @@ -0,0 +1,157 @@ +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +.globl cplx_ifft16_avx_fma +cplx_ifft16_avx_fma: +vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.fourth_pass: +vmovupd 0(%rsi),%ymm0 /* gamma */ +vmovupd 32(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma +vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma +vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta +vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmsubadd231pd %ymm13, %ymm1, %ymm5 +vfmaddsub231pd %ymm14, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + + +.third_pass: +vmovupd 64(%rsi),%xmm0 /* gamma */ +vmovupd 80(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vsubpd %ymm9,%ymm8,%ymm4 +vsubpd %ymm11,%ymm10,%ymm5 +vsubpd %ymm13,%ymm12,%ymm6 +vsubpd %ymm15,%ymm14,%ymm7 +vaddpd %ymm9,%ymm8,%ymm8 +vaddpd %ymm11,%ymm10,%ymm10 +vaddpd %ymm13,%ymm12,%ymm12 +vaddpd %ymm15,%ymm14,%ymm14 +vshufpd $5, %ymm4, %ymm4, %ymm9 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm13 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm9,%ymm1,%ymm9 +vmulpd %ymm11,%ymm0,%ymm11 +vmulpd 
%ymm13,%ymm3,%ymm13 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9 +vfmsubadd231pd %ymm5, %ymm1, %ymm11 +vfmaddsub231pd %ymm6, %ymm2, %ymm13 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 + +.second_pass: +vmovupd 96(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm10,%ymm8,%ymm4 +vsubpd %ymm11,%ymm9,%ymm5 +vsubpd %ymm14,%ymm12,%ymm6 +vsubpd %ymm15,%ymm13,%ymm7 +vaddpd %ymm10,%ymm8,%ymm8 +vaddpd %ymm11,%ymm9,%ymm9 +vaddpd %ymm14,%ymm12,%ymm12 +vaddpd %ymm15,%ymm13,%ymm13 +vshufpd $5, %ymm4, %ymm4, %ymm10 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm10,%ymm1,%ymm10 +vmulpd %ymm11,%ymm1,%ymm11 +vmulpd %ymm14,%ymm0,%ymm14 +vmulpd %ymm15,%ymm0,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10 +vfmaddsub231pd %ymm5, %ymm0, %ymm11 +vfmsubadd231pd %ymm6, %ymm1, %ymm14 +vfmsubadd231pd %ymm7, %ymm1, %ymm15 + +.first_pass: +vmovupd 112(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm12,%ymm8,%ymm4 +vsubpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vsubpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vaddpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vaddpd %ymm15,%ymm11,%ymm11 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm1,%ymm13 +vmulpd %ymm14,%ymm1,%ymm14 +vmulpd %ymm15,%ymm1,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmaddsub231pd %ymm5, %ymm0, %ymm13 +vfmaddsub231pd %ymm6, %ymm0, %ymm14 +vfmaddsub231pd %ymm7, %ymm0, %ymm15 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret +.size cplx_ifft16_avx_fma, .-cplx_ifft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s b/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s new file mode 100644 index 0000000..1803882 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s @@ -0,0 +1,192 @@ + .text + .p2align 4 + .globl cplx_ifft16_avx_fma + .def cplx_ifft16_avx_fma; .scl 2; .type 32; .endef +cplx_ifft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq cplx_ifft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = 
sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) + +cplx_ifft16_avx_fma_amd64: +vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.fourth_pass: +vmovupd 0(%rsi),%ymm0 /* gamma */ +vmovupd 32(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma +vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma +vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta +vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmsubadd231pd %ymm13, %ymm1, %ymm5 +vfmaddsub231pd %ymm14, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + + +.third_pass: +vmovupd 64(%rsi),%xmm0 /* gamma */ +vmovupd 80(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vsubpd %ymm9,%ymm8,%ymm4 +vsubpd %ymm11,%ymm10,%ymm5 +vsubpd %ymm13,%ymm12,%ymm6 +vsubpd %ymm15,%ymm14,%ymm7 +vaddpd %ymm9,%ymm8,%ymm8 +vaddpd %ymm11,%ymm10,%ymm10 +vaddpd %ymm13,%ymm12,%ymm12 +vaddpd %ymm15,%ymm14,%ymm14 +vshufpd $5, %ymm4, %ymm4, %ymm9 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm13 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm9,%ymm1,%ymm9 +vmulpd %ymm11,%ymm0,%ymm11 +vmulpd %ymm13,%ymm3,%ymm13 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9 +vfmsubadd231pd %ymm5, %ymm1, %ymm11 +vfmaddsub231pd %ymm6, %ymm2, %ymm13 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 + +.second_pass: +vmovupd 96(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd 
$15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm10,%ymm8,%ymm4 +vsubpd %ymm11,%ymm9,%ymm5 +vsubpd %ymm14,%ymm12,%ymm6 +vsubpd %ymm15,%ymm13,%ymm7 +vaddpd %ymm10,%ymm8,%ymm8 +vaddpd %ymm11,%ymm9,%ymm9 +vaddpd %ymm14,%ymm12,%ymm12 +vaddpd %ymm15,%ymm13,%ymm13 +vshufpd $5, %ymm4, %ymm4, %ymm10 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm10,%ymm1,%ymm10 +vmulpd %ymm11,%ymm1,%ymm11 +vmulpd %ymm14,%ymm0,%ymm14 +vmulpd %ymm15,%ymm0,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10 +vfmaddsub231pd %ymm5, %ymm0, %ymm11 +vfmsubadd231pd %ymm6, %ymm1, %ymm14 +vfmsubadd231pd %ymm7, %ymm1, %ymm15 + +.first_pass: +vmovupd 112(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm12,%ymm8,%ymm4 +vsubpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vsubpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vaddpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vaddpd %ymm15,%ymm11,%ymm11 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm1,%ymm13 +vmulpd %ymm14,%ymm1,%ymm14 +vmulpd %ymm15,%ymm1,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmaddsub231pd %ymm5, %ymm0, %ymm13 +vfmaddsub231pd %ymm6, %ymm0, %ymm14 +vfmaddsub231pd %ymm7, %ymm0, %ymm15 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret diff --git a/spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c b/spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c new file mode 100644 index 0000000..8bfd347 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c @@ -0,0 +1,267 @@ +#include +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; +typedef double D2MEM[2]; + +/** + * @brief complex ifft via bfs strategy (for m between 2 and 8) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_bfs_2(D4MEM* dat, const D2MEM** omga, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + { + // loop with h = 1 + // we do not do any particular optimization in this loop, + // since this function is only called for small dimensions + D4MEM* dd = (D4MEM*)data; + do { + /* + BEGIN_TEMPLATE + const __m256d ab% = _mm256_loadu_pd(dd[0+2*%]); + const __m256d cd% = _mm256_loadu_pd(dd[1+2*%]); + const __m256d ac% = _mm256_permute2f128_pd(ab%, cd%, 0b100000); + const __m256d bd% = _mm256_permute2f128_pd(ab%, cd%, 0b110001); + const __m256d sum% = _mm256_add_pd(ac%, bd%); + const __m256d diff% = _mm256_sub_pd(ac%, bd%); + const __m256d diffbar% = _mm256_shuffle_pd(diff%, diff%, 5); + const __m256d om% = _mm256_load_pd((*omg)[0+%]); + const __m256d omre% = _mm256_unpacklo_pd(om%, om%); + const __m256d omim% = _mm256_unpackhi_pd(om%, om%); + const __m256d t1% = _mm256_mul_pd(diffbar%, omim%); + const __m256d t2% = _mm256_fmaddsub_pd(diff%, omre%, t1%); + const __m256d newab% = 
_mm256_permute2f128_pd(sum%, t2%, 0b100000); + const __m256d newcd% = _mm256_permute2f128_pd(sum%, t2%, 0b110001); + _mm256_storeu_pd(dd[0+2*%], newab%); + _mm256_storeu_pd(dd[1+2*%], newcd%); + dd += 2*@; + *omg += 2*@; + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 1 + const __m256d ab0 = _mm256_loadu_pd(dd[0 + 2 * 0]); + const __m256d cd0 = _mm256_loadu_pd(dd[1 + 2 * 0]); + const __m256d ac0 = _mm256_permute2f128_pd(ab0, cd0, 0b100000); + const __m256d bd0 = _mm256_permute2f128_pd(ab0, cd0, 0b110001); + const __m256d sum0 = _mm256_add_pd(ac0, bd0); + const __m256d diff0 = _mm256_sub_pd(ac0, bd0); + const __m256d diffbar0 = _mm256_shuffle_pd(diff0, diff0, 5); + const __m256d om0 = _mm256_load_pd((*omga)[0 + 0]); + const __m256d omre0 = _mm256_unpacklo_pd(om0, om0); + const __m256d omim0 = _mm256_unpackhi_pd(om0, om0); + const __m256d t10 = _mm256_mul_pd(diffbar0, omim0); + const __m256d t20 = _mm256_fmaddsub_pd(diff0, omre0, t10); + const __m256d newab0 = _mm256_permute2f128_pd(sum0, t20, 0b100000); + const __m256d newcd0 = _mm256_permute2f128_pd(sum0, t20, 0b110001); + _mm256_storeu_pd(dd[0 + 2 * 0], newab0); + _mm256_storeu_pd(dd[1 + 2 * 0], newcd0); + dd += 2 * 1; + *omga += 2 * 1; + // END_INTERLEAVE + } while (dd < finaldd); +#if 0 + printf("c after first: "); + for (uint64_t ii=0; ii> 1; + for (uint32_t _2nblock = 2; _2nblock <= ms2; _2nblock <<= 1) { + // _2nblock = h in ref code + uint32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd((*omga)[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const __m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omga += 2; + } while (dd < finaldd); + } +} + +/** + * @brief complex fft via bfs strategy (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_bfs_16(D4MEM* dat, const D2MEM** omga, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + // base iteration when h = _2nblock == 8 + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_ifft16_avx_fma(dd, *omga); + dd += 8; + *omga += 8; + } while (dd < finaldd); + } + // general case + const uint32_t log2m = _mm_popcnt_u32(m-1); //_popcnt32(m-1); //log2(m); + uint32_t h=16; + if (log2m % 2 == 1) { + uint32_t nblock = h >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m128d om1 = _mm_loadu_pd((*omga)[0]); + const __m256d om = _mm256_set_m128d(om1,om1); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const 
__m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omga += 1; + } while (dd < finaldd); + h = 32; + } + for (; h < m; h <<= 2) { + // _2nblock = h in ref code + uint32_t nblock = h >> 1; // =h/2 in ref code + D4MEM* dd0 = (D4MEM*)data; + do { + const __m128d om1 = _mm_loadu_pd((*omga)[0]); + const __m128d al1 = _mm_loadu_pd((*omga)[1]); + const __m256d om = _mm256_set_m128d(om1,om1); + const __m256d al = _mm256_set_m128d(al1,al1); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d alre = _mm256_unpacklo_pd(al, al); + const __m256d alim = _mm256_unpackhi_pd(al, al); + D4MEM* const ddend = (dd0 + nblock); + D4MEM* dd1 = ddend; + D4MEM* dd2 = dd1 + nblock; + D4MEM* dd3 = dd2 + nblock; + do { + __m256d u0 = _mm256_loadu_pd(dd0[0]); + __m256d u1 = _mm256_loadu_pd(dd1[0]); + __m256d u2 = _mm256_loadu_pd(dd2[0]); + __m256d u3 = _mm256_loadu_pd(dd3[0]); + __m256d u4 = _mm256_add_pd(u0, u1); + __m256d u5 = _mm256_sub_pd(u0, u1); + __m256d u6 = _mm256_add_pd(u2, u3); + __m256d u7 = _mm256_sub_pd(u2, u3); + u0 = _mm256_shuffle_pd(u5, u5, 5); + u2 = _mm256_shuffle_pd(u7, u7, 5); + u1 = _mm256_mul_pd(u0, omim); + u3 = _mm256_mul_pd(u2, omre); + u5 = _mm256_fmaddsub_pd(u5,omre, u1); + u7 = _mm256_fmsubadd_pd(u7,omim, u3); + ////// + u0 = _mm256_add_pd(u4,u6); + u1 = _mm256_add_pd(u5,u7); + u2 = _mm256_sub_pd(u4,u6); + u3 = _mm256_sub_pd(u5,u7); + u4 = _mm256_shuffle_pd(u2, u2, 5); + u5 = _mm256_shuffle_pd(u3, u3, 5); + u6 = _mm256_mul_pd(u4, alim); + u7 = _mm256_mul_pd(u5, alim); + u2 = _mm256_fmaddsub_pd(u2,alre, u6); + u3 = _mm256_fmaddsub_pd(u3,alre, u7); + /////// + _mm256_storeu_pd(dd0[0], u0); + _mm256_storeu_pd(dd1[0], u1); + _mm256_storeu_pd(dd2[0], u2); + _mm256_storeu_pd(dd3[0], u3); + dd0 += 1; + dd1 += 1; + dd2 += 1; + dd3 += 1; + } while (dd0 < ddend); + dd0 += 3*nblock; + *omga += 2; + } while (dd0 < finaldd); + } +} + +/** + * @brief complex ifft via dfs recursion (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_rec_16(D4MEM* dat, const D2MEM** omga, uint32_t m) { + if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(dat, omga, m); + if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(dat, omga, m); + const uint32_t _2nblock = m >> 1; // = h in ref code + const uint32_t nblock = _2nblock >> 1; // =h/2 in ref code + cplx_ifft_avx2_fma_rec_16(dat, omga, _2nblock); + cplx_ifft_avx2_fma_rec_16(dat + nblock, omga, _2nblock); + { + // final iteration + D4MEM* dd = dat; + const __m256d om = _mm256_load_pd((*omga)[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const __m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); 
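      /* The intrinsics above implement the inverse butterfly
       * (a, b) <- (a + b, om * (a - b)) on two complexes per __m256d:
       * omim holds the twiddle's imaginary part with alternating (-,+) signs
       * and bardiff swaps the re/im lanes of the difference, so
       * fmadd(bardiff, omim, diff*omre) is the complex product om * (a - b),
       * matching invctwiddle() in cplx_ifft_ref.c. */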
+ dd += 1; + ddmid += 1; + } while (dd < ddend); + *omga += 2; + } +} + +/** + * @brief complex ifft via best strategy (for m>=1) + * @param dat the data to run the algorithm on: m complex numbers + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* precomp, void* d) { + const uint32_t m = precomp->m; + const D2MEM* omg = (D2MEM*)precomp->powomegas; + if (m <= 1) return; + if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(d, &omg, m); + if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(d, &omg, m); + cplx_ifft_avx2_fma_rec_16(d, &omg, m); +} diff --git a/spqlios/lib/spqlios/cplx/cplx_ifft_ref.c b/spqlios/lib/spqlios/cplx/cplx_ifft_ref.c new file mode 100644 index 0000000..a3405a6 --- /dev/null +++ b/spqlios/lib/spqlios/cplx/cplx_ifft_ref.c @@ -0,0 +1,315 @@ +#include + +#include "../commons_private.h" +#include "cplx_fft.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +/** @brief (a,b) <- (a+b,omegabar.(a-b)) */ +void invctwiddle(CPLX a, CPLX b, const CPLX ombar) { + double diffre = a[0] - b[0]; + double diffim = a[1] - b[1]; + a[0] = a[0] + b[0]; + a[1] = a[1] + b[1]; + b[0] = diffre * ombar[0] - diffim * ombar[1]; + b[1] = diffre * ombar[1] + diffim * ombar[0]; +} + +/** @brief (a,b) <- (a+b,-i.omegabar(a-b)) */ +void invcitwiddle(CPLX a, CPLX b, const CPLX ombar) { + double diffre = a[0] - b[0]; + double diffim = a[1] - b[1]; + a[0] = a[0] + b[0]; + a[1] = a[1] + b[1]; + //-i(x+iy)=-ix+y + b[0] = diffre * ombar[1] + diffim * ombar[0]; + b[1] = -diffre * ombar[0] + diffim * ombar[1]; +} + +/** @brief exp(-i.2pi.x) */ +void cplx_set_e2pimx(CPLX res, double x) { + res[0] = m_accurate_cos(2 * M_PI * x); + res[1] = -m_accurate_sin(2 * M_PI * x); +} +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... + * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +double fracrevbits(uint32_t i); +/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */ +void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data) { + if (m == 1) return; + const double pom = entry_pwr / 2.; + const uint32_t h = m / 2; + CPLX cpom; + cplx_set_e2pimx(cpom, pom); + // do the recursive calls + cplx_ifft_naive(h, pom, data); + cplx_ifft_naive(h, pom + 0.5, data + h); + // apply the inverse twiddle factors + for (uint64_t i = 0; i < h; ++i) { + invctwiddle(data[i], data[i + h], cpom); + } +} + +void cplx_ifft16_precomp(const double entry_pwr, CPLX** omg) { + static const double j_pow = 1. / 8.; + static const double k_pow = 1. 
/ 16.; + const double pom = entry_pwr / 2.; + const double pom_2 = entry_pwr / 4.; + const double pom_4 = entry_pwr / 8.; + const double pom_8 = entry_pwr / 16.; + cplx_set_e2pimx((*omg)[0], pom_8); + cplx_set_e2pimx((*omg)[1], pom_8 + j_pow); + cplx_set_e2pimx((*omg)[2], pom_8 + k_pow); + cplx_set_e2pimx((*omg)[3], pom_8 + j_pow + k_pow); + cplx_set_e2pimx((*omg)[4], pom_4); + cplx_set_e2pimx((*omg)[5], pom_4 + j_pow); + cplx_set_e2pimx((*omg)[6], pom_2); + cplx_set_e2pimx((*omg)[7], pom); + *omg += 8; +} + +/** + * @brief iFFT modulo X^16-omega^2 (in registers) + * @param data contains 16 complexes + * @param omegabar 8 complexes in this order: + * gammabar,jb.gammabar,kb.gammabar,kbjb.gammabar, + * betabar,jb.betabar,alphabar,omegabar + * alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) + * jb = sqrt(ib), kb=sqrt(jb) + */ +void cplx_ifft16_ref(void* data, const void* omegabar) { + CPLX* d = data; + const CPLX* om = omegabar; + // fourth pass inverse + invctwiddle(d[0], d[1], om[0]); + invcitwiddle(d[2], d[3], om[0]); + invctwiddle(d[4], d[5], om[1]); + invcitwiddle(d[6], d[7], om[1]); + invctwiddle(d[8], d[9], om[2]); + invcitwiddle(d[10], d[11], om[2]); + invctwiddle(d[12], d[13], om[3]); + invcitwiddle(d[14], d[15], om[3]); + // third pass inverse + invctwiddle(d[0], d[2], om[4]); + invctwiddle(d[1], d[3], om[4]); + invcitwiddle(d[4], d[6], om[4]); + invcitwiddle(d[5], d[7], om[4]); + invctwiddle(d[8], d[10], om[5]); + invctwiddle(d[9], d[11], om[5]); + invcitwiddle(d[12], d[14], om[5]); + invcitwiddle(d[13], d[15], om[5]); + // second pass inverse + invctwiddle(d[0], d[4], om[6]); + invctwiddle(d[1], d[5], om[6]); + invctwiddle(d[2], d[6], om[6]); + invctwiddle(d[3], d[7], om[6]); + invcitwiddle(d[8], d[12], om[6]); + invcitwiddle(d[9], d[13], om[6]); + invcitwiddle(d[10], d[14], om[6]); + invcitwiddle(d[11], d[15], om[6]); + // first pass + for (uint64_t i = 0; i < 8; ++i) { + invctwiddle(d[0 + i], d[8 + i], om[7]); + } +} + +void cplx_ifft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) { + CPLX* const dend = dat + m; + for (CPLX* d = dat; d < dend; d += 2) { + split_fft_last_ref(d, (*omg)[0]); + *omg += 1; + } +#if 0 + printf("after first: "); + for (uint64_t ii=0; iim; + const CPLX* omg = (CPLX*)precomp->powomegas; + if (m == 1) return; + if (m <= 8) return cplx_ifft_ref_bfs_2(data, &omg, m); + if (m <= 2048) return cplx_ifft_ref_bfs_16(data, &omg, m); + cplx_ifft_ref_rec_16(data, &omg, m); +} + +EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(CPLX)); + const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX)); + void* reps = malloc(sizeof(CPLX_IFFT_PRECOMP) + 63 // padding + + OMG_SPACE // tables + + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t) reps + sizeof(CPLX_IFFT_PRECOMP)); + CPLX_IFFT_PRECOMP* r = (CPLX_IFFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + CPLX* omg = (CPLX*)r->powomegas; + fill_cplx_ifft_omegas_rec_16(0.25, &omg, m); + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + { + if (m <= 4) { + // currently, we do not have any acceletated + // implementation for m<=4 + r->function = cplx_ifft_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_ifft_avx2_fma; + } else { + r->function = cplx_ifft_ref; + } + } + return reps; +} + +EXPORT void* cplx_ifft_precomp_get_buffer(const 
CPLX_IFFT_PRECOMP* itables, uint32_t buffer_index) { + return (uint8_t*) itables->aligned_buffers + buffer_index * itables->buf_size; +} + +EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { + itables->function(itables, data); +} + +EXPORT void cplx_ifft_simple(uint32_t m, void* data) { + static CPLX_IFFT_PRECOMP* p[31] = {0}; + CPLX_IFFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_cplx_ifft_precomp(m, 0); + (*f)->function(*f, data); +} + diff --git a/spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c b/spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c new file mode 100644 index 0000000..e69de29 diff --git a/spqlios/lib/spqlios/ext/neon_accel/macrof.h b/spqlios/lib/spqlios/ext/neon_accel/macrof.h new file mode 100644 index 0000000..ac48eef --- /dev/null +++ b/spqlios/lib/spqlios/ext/neon_accel/macrof.h @@ -0,0 +1,138 @@ +/* + * This file is extracted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/macrof.h). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * This 64-bit Floating point NEON macro x1 has not been modified and is provided as is. + */ + +#ifndef MACROF_H +#define MACROF_H + +#include + +// c <= addr x1 +#define vload(c, addr) c = vld1q_f64(addr); +// c <= addr interleave 2 +#define vload2(c, addr) c = vld2q_f64(addr); +// c <= addr interleave 4 +#define vload4(c, addr) c = vld4q_f64(addr); + +#define vstore(addr, c) vst1q_f64(addr, c); +// addr <= c +#define vstore2(addr, c) vst2q_f64(addr, c); +// addr <= c +#define vstore4(addr, c) vst4q_f64(addr, c); + +// c <= addr x2 +#define vloadx2(c, addr) c = vld1q_f64_x2(addr); +// c <= addr x3 +#define vloadx3(c, addr) c = vld1q_f64_x3(addr); + +// addr <= c +#define vstorex2(addr, c) vst1q_f64_x2(addr, c); + +// c = a - b +#define vfsub(c, a, b) c = vsubq_f64(a, b); + +// c = a + b +#define vfadd(c, a, b) c = vaddq_f64(a, b); + +// c = a * b +#define vfmul(c, a, b) c = vmulq_f64(a, b); + +// c = a * n (n is constant) +#define vfmuln(c, a, n) c = vmulq_n_f64(a, n); + +// Swap from a|b to b|a +#define vswap(c, a) c = vextq_f64(a, a, 1); + +// c = a * b[i] +#define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i); + +// c = 1/a +#define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a); + +// c = -a +#define vfneg(c, a) c = vnegq_f64(a); + +#define transpose_f64(a, b, t, ia, ib, it) \ + t.val[it] = a.val[ia]; \ + a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \ + b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]); + +/* + * c = a + jb + * c[0] = a[0] - b[1] + * c[1] = a[1] + b[0] + */ +#define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b); + +/* + * c = a - jb + * c[0] = a[0] + b[1] + * c[1] = a[1] - b[0] + */ +#define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b); + +// c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0] +#define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b); + +// c[0] = c[0] - b[1]*a[1], c[1] = c[1] + b[0]*a[1] +#define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b); + +// c[0] = c[0] - b[0]*a[0], c[1] = c[1] - b[1]*a[0] +#define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b); + +// c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1] +#define vfcmla_270(c, a, b) c = 
vcmlaq_rot270_f64(c, a, b); + +/* + * Complex MUL: c = a*b + * c[0] = a[0]*b[0] - a[1]*b[1] + * c[1] = a[0]*b[1] + a[1]*b[0] + */ +#define FPC_CMUL(c, a, b) \ + c = vmulq_laneq_f64(b, a, 0); \ + c = vcmlaq_rot90_f64(c, a, b); + +/* + * Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1]) + * c[0] = b[0]*a[0] + b[1]*a[1] + * c[1] = + b[0]*a[1] - b[1]*a[0] + */ +#define FPC_CMUL_CONJ(c, a, b) \ + c = vmulq_laneq_f64(a, b, 0); \ + c = vcmlaq_rot270_f64(c, b, a); + +#if FMA == 1 +// d = c + a *b +#define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b); +// d = c - a * b +#define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b); +// d = c + a * b[i] +#define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i); +// d = c - a * b[i] +#define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i); + +#else + // d = c + a *b + #define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b)); + // d = c - a *b + #define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b)); + // d = c + a * b[i] + #define vfmla_lane(d, c, a, b, i) \ + d = vaddq_f64(c, vmulq_laneq_f64(a, b, i)); + + #define vfmls_lane(d, c, a, b, i) \ + d = vsubq_f64(c, vmulq_laneq_f64(a, b, i)); + +#endif + +#endif diff --git a/spqlios/lib/spqlios/ext/neon_accel/macrofx4.h b/spqlios/lib/spqlios/ext/neon_accel/macrofx4.h new file mode 100644 index 0000000..9bddc9f --- /dev/null +++ b/spqlios/lib/spqlios/ext/neon_accel/macrofx4.h @@ -0,0 +1,428 @@ +/* + * This file is extracted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/macrof.h). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * This 64-bit Floating point NEON macro x4 has not been modified and is provided as is. 
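A small sketch (not part of the Falcon-Arm sources above) showing how the interleaved complex multiply macro FPC_CMUL from macrof.h is meant to be used; it assumes an AArch64 target with the FCMA extension required by vcmlaq_rot90_f64, and the helper name is illustrative:

#include <arm_neon.h>

#include "macrof.h"

/* One complex per float64x2_t, stored as {re, im}: returns a * b. */
static inline float64x2_t cplx_mul_fcma(float64x2_t a, float64x2_t b) {
  float64x2_t c;
  FPC_CMUL(c, a, b);
  return c;
}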
+ */ + +#ifndef MACROFX4_H +#define MACROFX4_H + +#include +#include "macrof.h" + +#define vloadx4(c, addr) c = vld1q_f64_x4(addr); + +#define vstorex4(addr, c) vst1q_f64_x4(addr, c); + +#define vfdupx4(c, constant) \ + c.val[0] = vdupq_n_f64(constant); \ + c.val[1] = vdupq_n_f64(constant); \ + c.val[2] = vdupq_n_f64(constant); \ + c.val[3] = vdupq_n_f64(constant); + +#define vfnegx4(c, a) \ + c.val[0] = vnegq_f64(a.val[0]); \ + c.val[1] = vnegq_f64(a.val[1]); \ + c.val[2] = vnegq_f64(a.val[2]); \ + c.val[3] = vnegq_f64(a.val[3]); + +#define vfmulnx4(c, a, n) \ + c.val[0] = vmulq_n_f64(a.val[0], n); \ + c.val[1] = vmulq_n_f64(a.val[1], n); \ + c.val[2] = vmulq_n_f64(a.val[2], n); \ + c.val[3] = vmulq_n_f64(a.val[3], n); + +// c = a - b +#define vfsubx4(c, a, b) \ + c.val[0] = vsubq_f64(a.val[0], b.val[0]); \ + c.val[1] = vsubq_f64(a.val[1], b.val[1]); \ + c.val[2] = vsubq_f64(a.val[2], b.val[2]); \ + c.val[3] = vsubq_f64(a.val[3], b.val[3]); + +// c = a + b +#define vfaddx4(c, a, b) \ + c.val[0] = vaddq_f64(a.val[0], b.val[0]); \ + c.val[1] = vaddq_f64(a.val[1], b.val[1]); \ + c.val[2] = vaddq_f64(a.val[2], b.val[2]); \ + c.val[3] = vaddq_f64(a.val[3], b.val[3]); + +#define vfmulx4(c, a, b) \ + c.val[0] = vmulq_f64(a.val[0], b.val[0]); \ + c.val[1] = vmulq_f64(a.val[1], b.val[1]); \ + c.val[2] = vmulq_f64(a.val[2], b.val[2]); \ + c.val[3] = vmulq_f64(a.val[3], b.val[3]); + +#define vfmulx4_i(c, a, b) \ + c.val[0] = vmulq_f64(a.val[0], b); \ + c.val[1] = vmulq_f64(a.val[1], b); \ + c.val[2] = vmulq_f64(a.val[2], b); \ + c.val[3] = vmulq_f64(a.val[3], b); + +#define vfinvx4(c, a) \ + c.val[0] = vdivq_f64(vdupq_n_f64(1.0), a.val[0]); \ + c.val[1] = vdivq_f64(vdupq_n_f64(1.0), a.val[1]); \ + c.val[2] = vdivq_f64(vdupq_n_f64(1.0), a.val[2]); \ + c.val[3] = vdivq_f64(vdupq_n_f64(1.0), a.val[3]); + +#define vfcvtx4(c, a) \ + c.val[0] = vcvtq_f64_s64(a.val[0]); \ + c.val[1] = vcvtq_f64_s64(a.val[1]); \ + c.val[2] = vcvtq_f64_s64(a.val[2]); \ + c.val[3] = vcvtq_f64_s64(a.val[3]); + +#define vfmlax4(d, c, a, b) \ + vfmla(d.val[0], c.val[0], a.val[0], b.val[0]); \ + vfmla(d.val[1], c.val[1], a.val[1], b.val[1]); \ + vfmla(d.val[2], c.val[2], a.val[2], b.val[2]); \ + vfmla(d.val[3], c.val[3], a.val[3], b.val[3]); + +#define vfmlsx4(d, c, a, b) \ + vfmls(d.val[0], c.val[0], a.val[0], b.val[0]); \ + vfmls(d.val[1], c.val[1], a.val[1], b.val[1]); \ + vfmls(d.val[2], c.val[2], a.val[2], b.val[2]); \ + vfmls(d.val[3], c.val[3], a.val[3], b.val[3]); + +#define vfrintx4(c, a) \ + c.val[0] = vcvtnq_s64_f64(a.val[0]); \ + c.val[1] = vcvtnq_s64_f64(a.val[1]); \ + c.val[2] = vcvtnq_s64_f64(a.val[2]); \ + c.val[3] = vcvtnq_s64_f64(a.val[3]); + +/* + * Wrapper for FFT, split/merge and poly_float.c + */ + +#define FPC_MUL(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, a_re, b_re); \ + vfmls(d_re, d_re, a_im, b_im); \ + vfmul(d_im, a_re, b_im); \ + vfmla(d_im, d_im, a_im, b_re); + +#define FPC_MULx2(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); + +#define FPC_MULx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], 
d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmul(d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \ + vfmul(d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \ + vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \ + vfmul(d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \ + vfmul(d_im.val[3], a_re.val[3], b_im.val[3]); \ + vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]); + +#define FPC_MLA(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re, d_re, a_re, b_re); \ + vfmls(d_re, d_re, a_im, b_im); \ + vfmla(d_im, d_im, a_re, b_im); \ + vfmla(d_im, d_im, a_im, b_re); + +#define FPC_MLAx2(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); + +#define FPC_MLAx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \ + vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \ + vfmla(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \ + vfmla(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); \ + vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]); + +#define FPC_MUL_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], b_im.val[0], a_im.val[0]); \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmul(d_re.val[1], b_im.val[1], a_im.val[1]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmul(d_re.val[2], b_im.val[2], a_im.val[2]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmul(d_re.val[3], b_im.val[3], a_im.val[3]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmul(d_im.val[0], b_re.val[0], a_im.val[0]); \ + vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmul(d_im.val[1], b_re.val[1], a_im.val[1]); \ + vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmul(d_im.val[2], b_re.val[2], a_im.val[2]); \ + vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmul(d_im.val[3], b_re.val[3], a_im.val[3]); 
\ + vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); + +#define FPC_MLA_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], b_im.val[0], a_im.val[0]); \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], b_im.val[1], a_im.val[1]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmla(d_re.val[2], d_re.val[2], b_im.val[2], a_im.val[2]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmla(d_re.val[3], d_re.val[3], b_im.val[3], a_im.val[3]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmla(d_im.val[0], d_im.val[0], b_re.val[0], a_im.val[0]); \ + vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], b_re.val[1], a_im.val[1]); \ + vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[2], d_im.val[2], b_re.val[2], a_im.val[2]); \ + vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[3], d_im.val[3], b_re.val[3], a_im.val[3]); \ + vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); + +#define FPC_MUL_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 0); \ + vfmls_lane(d_re, d_re, a_im, b_re_im, 1); \ + vfmul_lane(d_im, a_re, b_re_im, 1); \ + vfmla_lane(d_im, d_im, a_im, b_re_im, 0); + +#define FPC_MUL_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \ + vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \ + vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \ + vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \ + vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \ + vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 1); \ + vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 0); \ + vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 1); \ + vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 0); \ + vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 1); \ + vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 0); \ + vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 1); \ + vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 0); + +#define FWD_TOP(t_re, t_im, b_re, b_im, zeta_re, zeta_im) \ + FPC_MUL(t_re, t_im, b_re, b_im, zeta_re, zeta_im); + +#define FWD_TOP_LANE(t_re, t_im, b_re, b_im, zeta) \ + FPC_MUL_LANE(t_re, t_im, b_re, b_im, zeta); + +#define FWD_TOP_LANEx4(t_re, t_im, b_re, b_im, zeta) \ + FPC_MUL_LANEx4(t_re, t_im, b_re, b_im, zeta); + +/* + * FPC + */ + +#define FPC_SUB(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vsubq_f64(a_re, b_re); \ + d_im = vsubq_f64(a_im, b_im); + +#define FPC_SUBx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vsubq_f64(a_re.val[0], b_re.val[0]); \ + d_im.val[0] = vsubq_f64(a_im.val[0], b_im.val[0]); \ + d_re.val[1] = vsubq_f64(a_re.val[1], b_re.val[1]); \ + d_im.val[1] = vsubq_f64(a_im.val[1], b_im.val[1]); \ + d_re.val[2] = vsubq_f64(a_re.val[2], b_re.val[2]); \ + d_im.val[2] = vsubq_f64(a_im.val[2], b_im.val[2]); \ + d_re.val[3] = vsubq_f64(a_re.val[3], b_re.val[3]); \ + d_im.val[3] = vsubq_f64(a_im.val[3], b_im.val[3]); + +#define FPC_ADD(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vaddq_f64(a_re, b_re); \ + d_im = vaddq_f64(a_im, b_im); + +#define FPC_ADDx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = 
vaddq_f64(a_re.val[0], b_re.val[0]); \ + d_im.val[0] = vaddq_f64(a_im.val[0], b_im.val[0]); \ + d_re.val[1] = vaddq_f64(a_re.val[1], b_re.val[1]); \ + d_im.val[1] = vaddq_f64(a_im.val[1], b_im.val[1]); \ + d_re.val[2] = vaddq_f64(a_re.val[2], b_re.val[2]); \ + d_im.val[2] = vaddq_f64(a_im.val[2], b_im.val[2]); \ + d_re.val[3] = vaddq_f64(a_re.val[3], b_re.val[3]); \ + d_im.val[3] = vaddq_f64(a_im.val[3], b_im.val[3]); + +#define FWD_BOT(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUB(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, t_re, t_im); + +#define FWD_BOTx4(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBx4(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, t_re, t_im); + +/* + * FPC_J + */ + +#define FPC_ADDJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vsubq_f64(a_re, b_im); \ + d_im = vaddq_f64(a_im, b_re); + +#define FPC_ADDJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vsubq_f64(a_re.val[0], b_im.val[0]); \ + d_im.val[0] = vaddq_f64(a_im.val[0], b_re.val[0]); \ + d_re.val[1] = vsubq_f64(a_re.val[1], b_im.val[1]); \ + d_im.val[1] = vaddq_f64(a_im.val[1], b_re.val[1]); \ + d_re.val[2] = vsubq_f64(a_re.val[2], b_im.val[2]); \ + d_im.val[2] = vaddq_f64(a_im.val[2], b_re.val[2]); \ + d_re.val[3] = vsubq_f64(a_re.val[3], b_im.val[3]); \ + d_im.val[3] = vaddq_f64(a_im.val[3], b_re.val[3]); + +#define FPC_SUBJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vaddq_f64(a_re, b_im); \ + d_im = vsubq_f64(a_im, b_re); + +#define FPC_SUBJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vaddq_f64(a_re.val[0], b_im.val[0]); \ + d_im.val[0] = vsubq_f64(a_im.val[0], b_re.val[0]); \ + d_re.val[1] = vaddq_f64(a_re.val[1], b_im.val[1]); \ + d_im.val[1] = vsubq_f64(a_im.val[1], b_re.val[1]); \ + d_re.val[2] = vaddq_f64(a_re.val[2], b_im.val[2]); \ + d_im.val[2] = vsubq_f64(a_im.val[2], b_re.val[2]); \ + d_re.val[3] = vaddq_f64(a_re.val[3], b_im.val[3]); \ + d_im.val[3] = vsubq_f64(a_im.val[3], b_re.val[3]); + +#define FWD_BOTJ(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBJ(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDJ(a_re, a_im, a_re, a_im, t_re, t_im); + +#define FWD_BOTJx4(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBJx4(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDJx4(a_re, a_im, a_re, a_im, t_re, t_im); + +//============== Inverse FFT +/* + * FPC_J + * a * conj(b) + * Original (without swap): + * d_re = b_im * a_im + a_re * b_re; + * d_im = b_re * a_im - a_re * b_im; + */ +#define FPC_MUL_BOTJ_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 0); \ + vfmla_lane(d_re, d_re, a_im, b_re_im, 1); \ + vfmul_lane(d_im, a_im, b_re_im, 0); \ + vfmls_lane(d_im, d_im, a_re, b_re_im, 1); + +#define FPC_MUL_BOTJ_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \ + vfmla_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_im.val[0], a_im.val[0], b_re_im, 0); \ + vfmls_lane(d_im.val[0], d_im.val[0], a_re.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \ + vfmla_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_im.val[1], a_im.val[1], b_re_im, 0); \ + vfmls_lane(d_im.val[1], d_im.val[1], a_re.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \ + vfmla_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_im.val[2], a_im.val[2], b_re_im, 0); \ + vfmls_lane(d_im.val[2], d_im.val[2], a_re.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], 
b_re_im, 0); \ + vfmla_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \ + vfmul_lane(d_im.val[3], a_im.val[3], b_re_im, 0); \ + vfmls_lane(d_im.val[3], d_im.val[3], a_re.val[3], b_re_im, 1); + +#define FPC_MUL_BOTJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, b_im, a_im); \ + vfmla(d_re, d_re, a_re, b_re); \ + vfmul(d_im, b_re, a_im); \ + vfmls(d_im, d_im, a_re, b_im); + +#define INV_TOPJ(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUB(t_re, t_im, a_re, a_im, b_re, b_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUBx4(t_re, t_im, a_re, a_im, b_re, b_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im) \ + FPC_MUL_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im); + +#define INV_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta) \ + FPC_MUL_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta); + +#define INV_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta) \ + FPC_MUL_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta); + +/* + * FPC_Jm + * a * -conj(b) + * d_re = a_re * b_im - a_im * b_re; + * d_im = a_im * b_im + a_re * b_re; + */ +#define FPC_MUL_BOTJm_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 1); \ + vfmls_lane(d_re, d_re, a_im, b_re_im, 0); \ + vfmul_lane(d_im, a_re, b_re_im, 0); \ + vfmla_lane(d_im, d_im, a_im, b_re_im, 1); + +#define FPC_MUL_BOTJm_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 1); \ + vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 0); \ + vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 0); \ + vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 1); \ + vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 0); \ + vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 0); \ + vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 1); \ + vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 0); \ + vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 0); \ + vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 1); \ + vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 0); \ + vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 0); \ + vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 1); + +#define FPC_MUL_BOTJm(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, a_re, b_im); \ + vfmls(d_re, d_re, a_im, b_re); \ + vfmul(d_im, a_im, b_im); \ + vfmla(d_im, d_im, a_re, b_re); + +#define INV_TOPJm(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUB(t_re, t_im, b_re, b_im, a_re, a_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_TOPJmx4(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUBx4(t_re, t_im, b_re, b_im, a_re, a_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im) \ + FPC_MUL_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im); + +#define INV_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta) \ + FPC_MUL_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta); + +#define INV_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta) \ + FPC_MUL_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta); + +#endif diff --git a/spqlios/lib/spqlios/q120/q120_arithmetic.h b/spqlios/lib/spqlios/q120/q120_arithmetic.h new file mode 100644 index 0000000..31745d0 --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_arithmetic.h @@ -0,0 +1,115 @@ +#ifndef SPQLIOS_Q120_ARITHMETIC_H 
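To make the split real/imaginary conventions of the FWD_* helpers above concrete, here is a small sketch of one radix-2 forward butterfly built from FWD_TOP and FWD_BOT; the wrapper name is illustrative, and it assumes macrofx4.h (and hence macrof.h and arm_neon.h) is on the include path of an AArch64 build:

#include <arm_neon.h>

#include "macrofx4.h"

/* (a, b) <- (a + zeta*b, a - zeta*b) on two complexes kept as split
   real/imaginary float64x2_t registers. */
static inline void fwd_butterfly(float64x2_t* a_re, float64x2_t* a_im,
                                 float64x2_t* b_re, float64x2_t* b_im,
                                 float64x2_t zeta_re, float64x2_t zeta_im) {
  float64x2_t t_re, t_im;
  FWD_TOP(t_re, t_im, *b_re, *b_im, zeta_re, zeta_im); /* t = zeta * b */
  FWD_BOT(*a_re, *a_im, *b_re, *b_im, t_re, t_im);     /* b = a - t, then a = a + t */
}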
+#define SPQLIOS_Q120_ARITHMETIC_H + +#include + +#include "../commons.h" +#include "q120_common.h" + +typedef struct _q120_mat1col_product_baa_precomp q120_mat1col_product_baa_precomp; +typedef struct _q120_mat1col_product_bbb_precomp q120_mat1col_product_bbb_precomp; +typedef struct _q120_mat1col_product_bbc_precomp q120_mat1col_product_bbc_precomp; + +EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp(); +EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp*); +EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp(); +EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp*); +EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp(); +EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp*); + +// ell < 10000 +EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res, + const q120a* const x, const q120a* const y); +EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120b* const y); +EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120c* const y); + +EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res, + const q120a* const x, const q120a* const y); +EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120b* const y); +EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120c* const y); + +EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +/** + * @brief extract 1 q120x2 block from one q120 ntt vectors + * @param nn the size of each vector + * @param blk the block id to extract ( + +#include "q120_arithmetic.h" +#include "q120_arithmetic_private.h" + +EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell, + q120b* const res, const q120a* const x, const q120a* const y) { + /** + * Algorithm: + * - res = acc1 + acc2 . 
((2^H) % Q) + * - acc1 is the sum of H LSB of products x[i].y[i] + * - acc2 is the sum of 64-H MSB of products x[i]].y[i] + * - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits + * - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits + */ + + const uint64_t H = precomp->h; + const __m256i MASK = _mm256_set1_epi64x((UINT64_C(1) << H) - 1); + + __m256i acc1 = _mm256_setzero_si256(); + __m256i acc2 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i a = _mm256_loadu_si256(x_ptr); + __m256i b = _mm256_loadu_si256(y_ptr); + __m256i t = _mm256_mul_epu32(a, b); + + acc1 = _mm256_add_epi64(acc1, _mm256_and_si256(t, MASK)); + acc2 = _mm256_add_epi64(acc2, _mm256_srli_epi64(t, H)); + + x_ptr++; + y_ptr++; + } + + const __m256i H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->h_pow_red); + + __m256i t = _mm256_add_epi64(acc1, _mm256_mul_epu32(acc2, H_POW_RED)); + _mm256_storeu_si256((__m256i*)res, t); +} + +EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120b* const y) { + /** + * Algorithm: + * 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - y_i = yl_i + yh_i . 2^32 + * - A_i = xl_i . yl_i + * - B_i = xl_i . yh_i + * - C_i = xh_i . yl_i + * - D_i = xh_i . yh_i + * - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64 + * 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * - C_i = Cl_i + Ch_i . 2^32 + * - D_i = Dl_i + Dh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + * - S2 = \sum (Ah_i + Bl_i + Cl_i) + * - S3 = \sum (Bh_i + Ch_i + Dl_i) + * - S4 = \sum Dh_i + * - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) + + * log2(3) bits + * - for ell == 10000 S2, S3 have < 47 bits + * 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2)) + * - S1 = S1l + S1h . 2^24 + * - S2 = S2l + S2h . 2^24 + * - S3 = S3l + S3h . 2^24 + * - S4 = S4l + S4h . 2^24 + * 5. Compute final result as: + * - \sum x_i . y_i = S1l + S1h . 2^24 + * + S2l . 2^32 + S2h . 2^(32+24) + * + S3l . 2^64 + S3h . 2^(64 + 24) + * + S4l . 2^96 + S4l . 
2^(96+24) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on 24 + 3 + bit size of primes Q + */ + const uint64_t H1 = 32; + const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1); + + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i x = _mm256_loadu_si256(x_ptr); + __m256i xl = _mm256_and_si256(x, MASK1); + __m256i xh = _mm256_srli_epi64(x, H1); + + __m256i y = _mm256_loadu_si256(y_ptr); + __m256i yl = _mm256_and_si256(y, MASK1); + __m256i yh = _mm256_srli_epi64(y, H1); + + __m256i a = _mm256_mul_epu32(xl, yl); + __m256i b = _mm256_mul_epu32(xl, yh); + __m256i c = _mm256_mul_epu32(xh, yl); + __m256i d = _mm256_mul_epu32(xh, yh); + + s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1)); + + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1)); + s2 = _mm256_add_epi64(s2, _mm256_and_si256(b, MASK1)); + s2 = _mm256_add_epi64(s2, _mm256_and_si256(c, MASK1)); + + s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(b, H1)); + s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(c, H1)); + s3 = _mm256_add_epi64(s3, _mm256_and_si256(d, MASK1)); + + s4 = _mm256_add_epi64(s4, _mm256_srli_epi64(d, H1)); + + x_ptr++; + y_ptr++; + } + + const uint64_t H2 = precomp->h; + const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); + + const __m256i S1H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s1h_pow_red); + __m256i s1l = _mm256_and_si256(s1, MASK2); + __m256i s1h = _mm256_srli_epi64(s1, H2); + __m256i t = _mm256_add_epi64(s1l, _mm256_mul_epu32(s1h, S1H_POW_RED)); + + const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); + const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); + __m256i s2l = _mm256_and_si256(s2, MASK2); + __m256i s2h = _mm256_srli_epi64(s2, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED)); + + const __m256i S3L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3l_pow_red); + const __m256i S3H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3h_pow_red); + __m256i s3l = _mm256_and_si256(s3, MASK2); + __m256i s3h = _mm256_srli_epi64(s3, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s3l, S3L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s3h, S3H_POW_RED)); + + const __m256i S4L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4l_pow_red); + const __m256i S4H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4h_pow_red); + __m256i s4l = _mm256_and_si256(s4, MASK2); + __m256i s4h = _mm256_srli_epi64(s4, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s4l, S4L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s4h, S4H_POW_RED)); + + _mm256_storeu_si256((__m256i*)res, t); +} + +EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + /** + * Algorithm: + * 0. We have + * - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q + * 1. Split x_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - A_i = xl_i . y1_i + * - B_i = xh_i . y2_i + * - we have x_i . y_i == A_i + B_i + * 2. Split A_i and B_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * 3. 
Compute the sums: + * - S1 = \sum Al_i + Bl_i + * - S2 = \sum Ah_i + Bh_i + * - here S1 and S2 have 32 + log2(ell) bits + * - for ell == 10000 S1, S2 have < 46 bits + * 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46) + * - S2 = S2l + S2h . 2^27 + * 5. Compute final result as: + * - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on < 52 bits + */ + + const uint64_t H1 = 32; + const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1); + + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i x = _mm256_loadu_si256(x_ptr); + __m256i xl = _mm256_and_si256(x, MASK1); + __m256i xh = _mm256_srli_epi64(x, H1); + + __m256i y = _mm256_loadu_si256(y_ptr); + __m256i y0 = _mm256_and_si256(y, MASK1); + __m256i y1 = _mm256_srli_epi64(y, H1); + + __m256i a = _mm256_mul_epu32(xl, y0); + __m256i b = _mm256_mul_epu32(xh, y1); + + s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1)); + s1 = _mm256_add_epi64(s1, _mm256_and_si256(b, MASK1)); + + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1)); + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(b, H1)); + + x_ptr++; + y_ptr++; + } + + const uint64_t H2 = precomp->h; + const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); + + __m256i t = s1; + + const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); + const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); + __m256i s2l = _mm256_and_si256(s2, MASK2); + __m256i s2h = _mm256_srli_epi64(s2, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED)); + + _mm256_storeu_si256((__m256i*)res, t); +} + +/** + * @deprecated keeping this one for history only. + * There is a slight register starvation condition on the q120x2_vec_mat2cols + * strategy below sounds better. 
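All three AVX2 kernels above rely on the lazy-reduction idea spelled out in their comments: the 64-bit products are accumulated as separate low-bit and high-bit sums, and the high sum is folded back only once at the end using powers of two reduced modulo the primes. A scalar sketch of the simplest (baa-style) variant, with illustrative names, a single modulus Q, one split point H (H < 64), and ell assumed small enough that the accumulators cannot overflow:

#include <stdint.h>

/* Returns a value congruent to sum(x[i] * y[i]) mod Q, not fully reduced. */
static uint64_t mat1col_product_scalar(uint64_t ell, const uint32_t* x,
                                       const uint32_t* y, uint64_t Q,
                                       unsigned H) {
  const uint64_t mask = (UINT64_C(1) << H) - 1;
  const uint64_t h_pow_red = (UINT64_C(1) << H) % Q; /* (2^H) % Q, precomputed */
  uint64_t acc1 = 0, acc2 = 0;
  for (uint64_t i = 0; i < ell; ++i) {
    const uint64_t t = (uint64_t)x[i] * (uint64_t)y[i];
    acc1 += t & mask; /* low H bits of the product */
    acc2 += t >> H;   /* remaining high bits */
  }
  return acc1 + acc2 * h_pow_red;
}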
+ */ +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); // col 2a + __m256i s5 = _mm256_setzero_si256(); + __m256i s6 = _mm256_setzero_si256(); // col 2b + __m256i s7 = _mm256_setzero_si256(); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_loadu_si256(x_ptr + 1); + s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 1); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1 + s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3 + s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1 + s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3 + + s10 = _mm256_slli_epi64(s12, 32); // -> s0 + s11 = _mm256_slli_epi64(s13, 32); // -> s2 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s3 + s10 = _mm256_srli_epi64(s10, 32); // -> s0 + s11 = _mm256_srli_epi64(s11, 32); // -> s2 + + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s12); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s13); + + s10 = _mm256_slli_epi64(s14, 32); // -> s0 + s11 = _mm256_slli_epi64(s15, 32); // -> s2 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s3 + s10 = _mm256_srli_epi64(s10, 32); // -> s0 + s11 = _mm256_srli_epi64(s11, 32); // -> s2 + + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s14); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s15); + + // deal with the second column + // s8,s9 are still in place! 
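+    // the x block kept in s8,s9 is reused: only its high 32-bit halves are
+    // recomputed into s10,s11, the second column is loaded from y_ptr + 2 and
+    // y_ptr + 3, and the partial products are accumulated into s4..s7 exactly
+    // as the first column was accumulated into s0..s3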
+ s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr + 2); + s13 = _mm256_loadu_si256(y_ptr + 3); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s4,s5 + s13 = _mm256_mul_epu32(s9, s13); // -> s6,s7 + s14 = _mm256_mul_epu32(s10, s14); // -> s4,s5 + s15 = _mm256_mul_epu32(s11, s15); // -> s6,s7 + + s10 = _mm256_slli_epi64(s12, 32); // -> s4 + s11 = _mm256_slli_epi64(s13, 32); // -> s6 + s12 = _mm256_srli_epi64(s12, 32); // -> s5 + s13 = _mm256_srli_epi64(s13, 32); // -> s7 + s10 = _mm256_srli_epi64(s10, 32); // -> s4 + s11 = _mm256_srli_epi64(s11, 32); // -> s6 + + s4 = _mm256_add_epi64(s4, s10); + s5 = _mm256_add_epi64(s5, s12); + s6 = _mm256_add_epi64(s6, s11); + s7 = _mm256_add_epi64(s7, s13); + + s10 = _mm256_slli_epi64(s14, 32); // -> s4 + s11 = _mm256_slli_epi64(s15, 32); // -> s6 + s14 = _mm256_srli_epi64(s14, 32); // -> s5 + s15 = _mm256_srli_epi64(s15, 32); // -> s7 + s10 = _mm256_srli_epi64(s10, 32); // -> s4 + s11 = _mm256_srli_epi64(s11, 32); // -> s6 + + s4 = _mm256_add_epi64(s4, s10); + s5 = _mm256_add_epi64(s5, s14); + s6 = _mm256_add_epi64(s6, s11); + s7 = _mm256_add_epi64(s7, s15); + + x_ptr += 2; + y_ptr += 4; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); + //--- s4,s5 + s11 = _mm256_and_si256(s5, s8); + s12 = _mm256_srli_epi64(s5, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s4 = _mm256_add_epi64(s4, s13); + s4 = _mm256_add_epi64(s4, s14); + _mm256_storeu_si256(res_ptr + 2, s4); + //--- s6,s7 + s11 = _mm256_and_si256(s7, s8); + s12 = _mm256_srli_epi64(s7, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s6 = _mm256_add_epi64(s6, s13); + s6 = _mm256_add_epi64(s6, s14); + _mm256_storeu_si256(res_ptr + 3, s6); +} + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); // col 2a + __m256i s5 = _mm256_setzero_si256(); + __m256i s6 = _mm256_setzero_si256(); // col 2b + __m256i s7 = _mm256_setzero_si256(); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + s11 = _mm256_set1_epi64x(0xFFFFFFFFUL); + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + // treat item a + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_srli_epi64(s8, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 2); + s14 = 
_mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // c1a -> s0,s1 + s13 = _mm256_mul_epu32(s8, s13); // c2a -> s4,s5 + s14 = _mm256_mul_epu32(s9, s14); // c1a -> s0,s1 + s15 = _mm256_mul_epu32(s9, s15); // c2a -> s4,s5 + + s8 = _mm256_and_si256(s12, s11); // -> s0 + s9 = _mm256_and_si256(s13, s11); // -> s4 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s5 + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s12); + s4 = _mm256_add_epi64(s4, s9); + s5 = _mm256_add_epi64(s5, s13); + + s8 = _mm256_and_si256(s14, s11); // -> s0 + s9 = _mm256_and_si256(s15, s11); // -> s4 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s5 + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s14); + s4 = _mm256_add_epi64(s4, s9); + s5 = _mm256_add_epi64(s5, s15); + + // treat item b + s8 = _mm256_loadu_si256(x_ptr + 1); + s9 = _mm256_srli_epi64(s8, 32); + + s12 = _mm256_loadu_si256(y_ptr + 1); + s13 = _mm256_loadu_si256(y_ptr + 3); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // c1b -> s2,s3 + s13 = _mm256_mul_epu32(s8, s13); // c2b -> s6,s7 + s14 = _mm256_mul_epu32(s9, s14); // c1b -> s2,s3 + s15 = _mm256_mul_epu32(s9, s15); // c2b -> s6,s7 + + s8 = _mm256_and_si256(s12, s11); // -> s2 + s9 = _mm256_and_si256(s13, s11); // -> s6 + s12 = _mm256_srli_epi64(s12, 32); // -> s3 + s13 = _mm256_srli_epi64(s13, 32); // -> s7 + s2 = _mm256_add_epi64(s2, s8); + s3 = _mm256_add_epi64(s3, s12); + s6 = _mm256_add_epi64(s6, s9); + s7 = _mm256_add_epi64(s7, s13); + + s8 = _mm256_and_si256(s14, s11); // -> s2 + s9 = _mm256_and_si256(s15, s11); // -> s6 + s14 = _mm256_srli_epi64(s14, 32); // -> s3 + s15 = _mm256_srli_epi64(s15, 32); // -> s7 + s2 = _mm256_add_epi64(s2, s8); + s3 = _mm256_add_epi64(s3, s14); + s6 = _mm256_add_epi64(s6, s9); + s7 = _mm256_add_epi64(s7, s15); + + x_ptr += 2; + y_ptr += 4; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); + //--- s4,s5 + s11 = _mm256_and_si256(s5, s8); + s12 = _mm256_srli_epi64(s5, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s4 = _mm256_add_epi64(s4, s13); + s4 = _mm256_add_epi64(s4, s14); + _mm256_storeu_si256(res_ptr + 2, s4); + //--- s6,s7 + s11 = _mm256_and_si256(s7, s8); + s12 = _mm256_srli_epi64(s7, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s6 = _mm256_add_epi64(s6, s13); + s6 = _mm256_add_epi64(s6, s14); + _mm256_storeu_si256(res_ptr + 3, s6); +} + +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a 
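+  // accumulators work in (low, high) pairs: s0/s1 collect the low/high 32-bit
+  // halves of the partial products of lane a, and s2/s3 do the same for lane b
+  // (these are the S1/S2 sums of the bbc algorithm described above)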
+ __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_set1_epi64x(0xFFFFFFFFUL); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_loadu_si256(x_ptr + 1); + s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 1); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1 + s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3 + s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1 + s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3 + + s8 = _mm256_and_si256(s12, s4); // -> s0 + s9 = _mm256_and_si256(s13, s4); // -> s2 + s10 = _mm256_and_si256(s14, s4); // -> s0 + s11 = _mm256_and_si256(s15, s4); // -> s2 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s3 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s3 + + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s12); + s2 = _mm256_add_epi64(s2, s9); + s3 = _mm256_add_epi64(s3, s13); + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s14); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s15); + + x_ptr += 2; + y_ptr += 2; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); +} diff --git a/spqlios/lib/spqlios/q120/q120_arithmetic_private.h b/spqlios/lib/spqlios/q120/q120_arithmetic_private.h new file mode 100644 index 0000000..399f989 --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_arithmetic_private.h @@ -0,0 +1,37 @@ +#ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H +#define SPQLIOS_Q120_ARITHMETIC_DEF_H + +#include + +typedef struct _q120_mat1col_product_baa_precomp { + uint64_t h; + uint64_t h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_baa_precomp; + +typedef struct _q120_mat1col_product_bbb_precomp { + uint64_t h; + uint64_t s1h_pow_red[4]; + uint64_t s2l_pow_red[4]; + uint64_t s2h_pow_red[4]; + uint64_t s3l_pow_red[4]; + uint64_t s3h_pow_red[4]; + uint64_t s4l_pow_red[4]; + uint64_t s4h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_bbb_precomp; + +typedef struct _q120_mat1col_product_bbc_precomp { + uint64_t h; + uint64_t s2l_pow_red[4]; + uint64_t s2h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_bbc_precomp; + +#endif // SPQLIOS_Q120_ARITHMETIC_DEF_H diff --git a/spqlios/lib/spqlios/q120/q120_arithmetic_ref.c b/spqlios/lib/spqlios/q120/q120_arithmetic_ref.c new file mode 
100644 index 0000000..cd20f10 --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_arithmetic_ref.c @@ -0,0 +1,506 @@ +#include +#include +#include + +#include "q120_arithmetic.h" +#include "q120_arithmetic_private.h" +#include "q120_common.h" + +#define MODQ(val, q) ((val) % (q)) + +double comp_bit_size_red(const uint64_t h, const uint64_t qs[4]) { + assert(h < 128); + double h_pow2_bs = 0; + for (uint64_t k = 0; k < 4; ++k) { + double t = log2((double)MODQ((__uint128_t)1 << h, qs[k])); + if (t > h_pow2_bs) h_pow2_bs = t; + } + return h_pow2_bs; +} + +double comp_bit_size_sum(const uint64_t n, const double* const bs) { + double s = 0; + for (uint64_t i = 0; i < n; ++i) { + s += pow(2, bs[i]); + } + return log2(s); +} + +void vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double min_res_bs = 1000; + uint64_t min_h = -1; + + double ell_bs = log2((double)MAX_ELL); + for (uint64_t h = 1; h < 64; ++h) { + double h_pow2_bs = comp_bit_size_red(h, qs); + + const double bs[] = {h + ell_bs, 64 - h + ell_bs + h_pow2_bs}; + const double res_bs = comp_bit_size_sum(2, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->h_pow_red[k] = MODQ(UINT64_C(1) << precomp->h, qs[k]); + } +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif + // printf("AA %lu %lf\n", min_h, min_res_bs); +} + +EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp() { + q120_mat1col_product_baa_precomp* res = malloc(sizeof(q120_mat1col_product_baa_precomp)); + vec_mat1col_product_baa_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell, + q120b* const res, const q120a* const x, const q120a* const y) { + /** + * Algorithm: + * - res = acc1 + acc2 . 
((2^H) % Q) + * - acc1 is the sum of H LSB of products x[i].y[i] + * - acc2 is the sum of 64-H MSB of products x[i]].y[i] + * - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits + * - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits + */ + const uint64_t H = precomp->h; + const uint64_t MASK = (UINT64_C(1) << H) - 1; + + uint64_t acc1[4] = {0, 0, 0, 0}; // accumulate H least significant bits of product + uint64_t acc2[4] = {0, 0, 0, 0}; // accumulate 64 - H most significan bits of product + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint64_t* const y_ptr = (uint64_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + uint64_t t = x_ptr[i + j] * y_ptr[i + j]; + acc1[j] += t & MASK; + acc2[j] += t >> H; + } + } + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t j = 0; j < 4; ++j) { + res_ptr[j] = acc1[j] + acc2[j] * precomp->h_pow_red[j]; + assert(log2(res_ptr[j]) < precomp->res_bit_size); + } +} + +void vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double ell_bs = log2((double)MAX_ELL); + double min_res_bs = 1000; + uint64_t min_h = -1; + + const double s1_bs = 32 + ell_bs; + const double s2_bs = 32 + ell_bs + log2(3); + const double s3_bs = 32 + ell_bs + log2(3); + const double s4_bs = 32 + ell_bs; + for (uint64_t h = 16; h < 32; ++h) { + const double s1l_bs = h; + const double s1h_bs = (s1_bs - h) + comp_bit_size_red(h, qs); + const double s2l_bs = h + comp_bit_size_red(32, qs); + const double s2h_bs = (s2_bs - h) + comp_bit_size_red(32 + h, qs); + const double s3l_bs = h + comp_bit_size_red(64, qs); + const double s3h_bs = (s3_bs - h) + comp_bit_size_red(64 + h, qs); + const double s4l_bs = h + comp_bit_size_red(96, qs); + const double s4h_bs = (s4_bs - h) + comp_bit_size_red(96 + h, qs); + + const double bs[] = {s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs}; + const double res_bs = comp_bit_size_sum(8, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->s1h_pow_red[k] = UINT64_C(1) << precomp->h; // 2^24 + precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); // 2^32 + precomp->s2h_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(32+24) + precomp->s3l_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^64 = 2^(32+32) + precomp->s3h_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(64+24) + precomp->s4l_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^96 = 2^(64+32) + precomp->s4h_pow_red[k] = MODQ(precomp->s4l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(96+24) + } +// printf("AA %lu %lf\n", min_h, min_res_bs); +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif +} + +EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp() { + q120_mat1col_product_bbb_precomp* res = malloc(sizeof(q120_mat1col_product_bbb_precomp)); + vec_mat1col_product_bbb_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120b* const y) { + /** + * Algorithm: + * 1. 
Split x_i and y_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - y_i = yl_i + yh_i . 2^32 + * - A_i = xl_i . yl_i + * - B_i = xl_i . yh_i + * - C_i = xh_i . yl_i + * - D_i = xh_i . yh_i + * - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64 + * 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * - C_i = Cl_i + Ch_i . 2^32 + * - D_i = Dl_i + Dh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + * - S2 = \sum (Ah_i + Bl_i + Cl_i) + * - S3 = \sum (Bh_i + Ch_i + Dl_i) + * - S4 = \sum Dh_i + * - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) + + * log2(3) bits + * - for ell == 10000 S2, S3 have < 47 bits + * 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2)) + * - S1 = S1l + S1h . 2^24 + * - S2 = S2l + S2h . 2^24 + * - S3 = S3l + S3h . 2^24 + * - S4 = S4l + S4h . 2^24 + * 5. Compute final result as: + * - \sum x_i . y_i = S1l + S1h . 2^24 + * + S2l . 2^32 + S2h . 2^(32+24) + * + S3l . 2^64 + S3h . 2^(64 + 24) + * + S4l . 2^96 + S4l . 2^(96+24) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on 24 + 3 + bit size of primes Q + */ + const uint64_t H1 = 32; + const uint64_t MASK1 = (UINT64_C(1) << H1) - 1; + + uint64_t s1[4] = {0, 0, 0, 0}; + uint64_t s2[4] = {0, 0, 0, 0}; + uint64_t s3[4] = {0, 0, 0, 0}; + uint64_t s4[4] = {0, 0, 0, 0}; + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint64_t* const y_ptr = (uint64_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t xl = x_ptr[i + j] & MASK1; + const uint64_t xh = x_ptr[i + j] >> H1; + const uint64_t yl = y_ptr[i + j] & MASK1; + const uint64_t yh = y_ptr[i + j] >> H1; + + const uint64_t a = xl * yl; + const uint64_t al = a & MASK1; + const uint64_t ah = a >> H1; + + const uint64_t b = xl * yh; + const uint64_t bl = b & MASK1; + const uint64_t bh = b >> H1; + + const uint64_t c = xh * yl; + const uint64_t cl = c & MASK1; + const uint64_t ch = c >> H1; + + const uint64_t d = xh * yh; + const uint64_t dl = d & MASK1; + const uint64_t dh = d >> H1; + + s1[j] += al; + s2[j] += ah + bl + cl; + s3[j] += bh + ch + dl; + s4[j] += dh; + } + } + + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t s1l = s1[j] & MASK2; + const uint64_t s1h = s1[j] >> H2; + const uint64_t s2l = s2[j] & MASK2; + const uint64_t s2h = s2[j] >> H2; + const uint64_t s3l = s3[j] & MASK2; + const uint64_t s3h = s3[j] >> H2; + const uint64_t s4l = s4[j] & MASK2; + const uint64_t s4h = s4[j] >> H2; + + uint64_t t = s1l; + t += s1h * precomp->s1h_pow_red[j]; + t += s2l * precomp->s2l_pow_red[j]; + t += s2h * precomp->s2h_pow_red[j]; + t += s3l * precomp->s3l_pow_red[j]; + t += s3h * precomp->s3h_pow_red[j]; + t += s4l * precomp->s4l_pow_red[j]; + t += s4h * precomp->s4h_pow_red[j]; + + res_ptr[j] = t; + assert(log2(res_ptr[j]) < precomp->res_bit_size); + } +} + +void vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double min_res_bs = 1000; + uint64_t min_h = -1; + + double pow2_32_bs = comp_bit_size_red(32, qs); + + double ell_bs = log2((double)MAX_ELL); + double s1_bs = 32 + ell_bs; + for (uint64_t h = 16; h < 32; ++h) { + double s2l_bs = pow2_32_bs + h; + double s2h_bs = s1_bs - h + comp_bit_size_red(32 + h, 
qs); + + const double bs[] = {s1_bs, s2l_bs, s2h_bs}; + const double res_bs = comp_bit_size_sum(3, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); + precomp->s2h_pow_red[k] = MODQ(UINT64_C(1) << (32 + precomp->h), qs[k]); + } +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif + // printf("AA %lu %lf\n", min_h, min_res_bs); +} + +EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp() { + q120_mat1col_product_bbc_precomp* res = malloc(sizeof(q120_mat1col_product_bbc_precomp)); + vec_mat1col_product_bbc_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_bbc_ref_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + /** + * Algorithm: + * 0. We have + * - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q + * 1. Split x_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - A_i = xl_i . y1_i + * - B_i = xh_i . y2_i + * - we have x_i . y_i == A_i + B_i + * 2. Split A_i and B_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + Bl_i + * - S2 = \sum Ah_i + Bh_i + * - here S1 and S2 have 32 + log2(ell) bits + * - for ell == 10000 S1, S2 have < 46 bits + * 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46) + * - S2 = S2l + S2h . 2^27 + * 5. Compute final result as: + * - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on < 52 bits + */ + + const uint64_t H1 = 32; + const uint64_t MASK1 = (UINT64_C(1) << H1) - 1; + + uint64_t s1[4] = {0, 0, 0, 0}; + uint64_t s2[4] = {0, 0, 0, 0}; + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint32_t* const y_ptr = (uint32_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t xl = x_ptr[i + j] & MASK1; + const uint64_t xh = x_ptr[i + j] >> H1; + const uint64_t y0 = y_ptr[2 * (i + j)]; + const uint64_t y1 = y_ptr[2 * (i + j) + 1]; + + const uint64_t a = xl * y0; + const uint64_t al = a & MASK1; + const uint64_t ah = a >> H1; + + const uint64_t b = xh * y1; + const uint64_t bl = b & MASK1; + const uint64_t bh = b >> H1; + + s1[j] += al + bl; + s2[j] += ah + bh; + } + } + + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t s2l = s2[k] & MASK2; + const uint64_t s2h = s2[k] >> H2; + + uint64_t t = s1[k]; + t += s2l * precomp->s2l_pow_red[k]; + t += s2h * precomp->s2h_pow_red[k]; + + res_ptr[k] = t; + assert(log2(res_ptr[k]) < precomp->res_bit_size); + } +} + +static __always_inline void accum_mul_q120_bc(uint64_t res[8], // + const uint32_t x_layb[8], const uint32_t y_layc[8]) { + for (uint64_t i = 0; i < 4; ++i) { + static const uint64_t MASK32 = 0xFFFFFFFFUL; + uint64_t x_lo = x_layb[2 * i]; + uint64_t x_hi = x_layb[2 * i + 1]; + uint64_t y_lo = y_layc[2 * i]; + uint64_t y_hi = y_layc[2 * i + 1]; + uint64_t xy_lo = x_lo * y_lo; + uint64_t xy_hi = x_hi * y_hi; + res[2 * i] += (xy_lo & MASK32) + (xy_hi & MASK32); + res[2 * i + 1] += 
(xy_lo >> 32) + (xy_hi >> 32); + } +} + +static __always_inline void accum_to_q120b(uint64_t res[4], // + const uint64_t s[8], const q120_mat1col_product_bbc_precomp* precomp) { + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t s2l = s[2 * k + 1] & MASK2; + const uint64_t s2h = s[2 * k + 1] >> H2; + uint64_t t = s[2 * k]; + t += s2l * precomp->s2l_pow_red[k]; + t += s2h * precomp->s2h_pow_red[k]; + res[k] = t; + assert(log2(res[k]) < precomp->res_bit_size); + } +} + +EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + const uint32_t(*const x_ptr)[8] = (const uint32_t(*const)[8])x; + const uint32_t(*const y_ptr)[8] = (const uint32_t(*const)[8])y; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s, x_ptr[i], y_ptr[i]); + } + accum_to_q120b((uint64_t*)res, s, precomp); +} + +EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[2][16] = {0}; + + const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x; + const uint32_t(*const y_ptr)[2][8] = (const uint32_t(*const)[2][8])y; + uint64_t(*re)[4] = (uint64_t(*)[4])res; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]); + accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]); + } + accum_to_q120b(re[0], s[0], precomp); + accum_to_q120b(re[1], s[1], precomp); +} + +EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[4][16] = {0}; + + const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x; + const uint32_t(*const y_ptr)[4][8] = (const uint32_t(*const)[4][8])y; + uint64_t(*re)[4] = (uint64_t(*)[4])res; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]); + accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]); + accum_mul_q120_bc(s[2], x_ptr[i][0], y_ptr[i][2]); + accum_mul_q120_bc(s[3], x_ptr[i][1], y_ptr[i][3]); + } + accum_to_q120b(re[0], s[0], precomp); + accum_to_q120b(re[1], s[1], precomp); + accum_to_q120b(re[2], s[2], precomp); + accum_to_q120b(re[3], s[3], precomp); +} + +EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk, + q120x2b* const dst, // 8 doubles + const q120b* const src // a q120b vector +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dst; + for (uint64_t i = 0; i < 8; ++i) { + out[i] = in[8 * blk + i]; + } +} + +// function on layout c is the exact same as on layout b +#ifdef __APPLE__ +#pragma weak q120x2_extract_1blk_from_q120c_ref = q120x2_extract_1blk_from_q120b_ref +#else +EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk, + q120x2c* const dst, // 8 doubles + const q120c* const src // a q120c vector + ) __attribute__((alias("q120x2_extract_1blk_from_q120b_ref"))); +#endif + +EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref( + uint64_t nn, uint64_t nrows, uint64_t blk, + q120x2b* const dst, // nrows * 2 q120 + const q120b* const src // a contiguous array of nrows q120b vectors +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dst; + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t 
i = 0; i < 8; ++i) { + out[i] = in[8 * blk + i]; + } + in += 4 * nn; + out += 8; + } +} + +EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk, + q120b* dest, // 1 reim vector of length m + const q120x2b* src // 8 doubles +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dest; + for (uint64_t i = 0; i < 8; ++i) { + out[8 * blk + i] = in[i]; + } +} diff --git a/spqlios/lib/spqlios/q120/q120_arithmetic_simple.c b/spqlios/lib/spqlios/q120/q120_arithmetic_simple.c new file mode 100644 index 0000000..0753ad5 --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_arithmetic_simple.c @@ -0,0 +1,111 @@ +#include +#include +#include + +#include "q120_arithmetic.h" +#include "q120_common.h" + +EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y) { + const uint64_t* x_u64 = (uint64_t*)x; + const uint64_t* y_u64 = (uint64_t*)y; + uint64_t* res_u64 = (uint64_t*)res; + for (uint64_t i = 0; i < 4 * nn; i += 4) { + res_u64[i + 0] = x_u64[i + 0] % ((uint64_t)Q1 << 33) + y_u64[i + 0] % ((uint64_t)Q1 << 33); + res_u64[i + 1] = x_u64[i + 1] % ((uint64_t)Q2 << 33) + y_u64[i + 1] % ((uint64_t)Q2 << 33); + res_u64[i + 2] = x_u64[i + 2] % ((uint64_t)Q3 << 33) + y_u64[i + 2] % ((uint64_t)Q3 << 33); + res_u64[i + 3] = x_u64[i + 3] % ((uint64_t)Q4 << 33) + y_u64[i + 3] % ((uint64_t)Q4 << 33); + } +} + +EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y) { + const uint32_t* x_u32 = (uint32_t*)x; + const uint32_t* y_u32 = (uint32_t*)y; + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0; i < 8 * nn; i += 8) { + res_u32[i + 0] = (uint32_t)(((uint64_t)x_u32[i + 0] + (uint64_t)y_u32[i + 0]) % Q1); + res_u32[i + 1] = (uint32_t)(((uint64_t)x_u32[i + 1] + (uint64_t)y_u32[i + 1]) % Q1); + res_u32[i + 2] = (uint32_t)(((uint64_t)x_u32[i + 2] + (uint64_t)y_u32[i + 2]) % Q2); + res_u32[i + 3] = (uint32_t)(((uint64_t)x_u32[i + 3] + (uint64_t)y_u32[i + 3]) % Q2); + res_u32[i + 4] = (uint32_t)(((uint64_t)x_u32[i + 4] + (uint64_t)y_u32[i + 4]) % Q3); + res_u32[i + 5] = (uint32_t)(((uint64_t)x_u32[i + 5] + (uint64_t)y_u32[i + 5]) % Q3); + res_u32[i + 6] = (uint32_t)(((uint64_t)x_u32[i + 6] + (uint64_t)y_u32[i + 6]) % Q4); + res_u32[i + 7] = (uint32_t)(((uint64_t)x_u32[i + 7] + (uint64_t)y_u32[i + 7]) % Q4); + } +} + +EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x) { + const uint64_t* x_u64 = (uint64_t*)x; + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0, j = 0; i < 4 * nn; i += 4, j += 8) { + res_u32[j + 0] = x_u64[i + 0] % Q1; + res_u32[j + 1] = ((uint64_t)res_u32[j + 0] << 32) % Q1; + res_u32[j + 2] = x_u64[i + 1] % Q2; + res_u32[j + 3] = ((uint64_t)res_u32[j + 2] << 32) % Q2; + res_u32[j + 4] = x_u64[i + 2] % Q3; + res_u32[j + 5] = ((uint64_t)res_u32[j + 4] << 32) % Q3; + res_u32[j + 6] = x_u64[i + 3] % Q4; + res_u32[j + 7] = ((uint64_t)res_u32[j + 6] << 32) % Q4; + } +} + +EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x) { + static const int64_t MASK_HI = INT64_C(0x8000000000000000); + static const int64_t MASK_LO = ~MASK_HI; + static const uint64_t OQ[4] = { + (Q1 - (UINT64_C(0x8000000000000000) % Q1)), + (Q2 - (UINT64_C(0x8000000000000000) % Q2)), + (Q3 - (UINT64_C(0x8000000000000000) % Q3)), + (Q4 - (UINT64_C(0x8000000000000000) % Q4)), + }; + uint64_t* res_u64 = (uint64_t*)res; + for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) { + uint64_t xj_lo = x[j] & MASK_LO; + uint64_t xj_hi = 
x[j] & MASK_HI; + res_u64[i + 0] = xj_lo + (xj_hi ? OQ[0] : 0); + res_u64[i + 1] = xj_lo + (xj_hi ? OQ[1] : 0); + res_u64[i + 2] = xj_lo + (xj_hi ? OQ[2] : 0); + res_u64[i + 3] = xj_lo + (xj_hi ? OQ[3] : 0); + } +} + +static int64_t posmod(int64_t x, int64_t q) { + int64_t t = x % q; + if (t < 0) + return t + q; + else + return t; +} + +EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x) { + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0, j = 0; j < nn; i += 8, ++j) { + res_u32[i + 0] = posmod(x[j], Q1); + res_u32[i + 1] = ((uint64_t)res_u32[i + 0] << 32) % Q1; + res_u32[i + 2] = posmod(x[j], Q2); + res_u32[i + 3] = ((uint64_t)res_u32[i + 2] << 32) % Q2; + res_u32[i + 4] = posmod(x[j], Q3); + res_u32[i + 5] = ((uint64_t)res_u32[i + 4] << 32) % Q3; + res_u32[i + 6] = posmod(x[j], Q4); + res_u32[i + 7] = ((uint64_t)res_u32[i + 6] << 32) % Q4; + ; + } +} + +EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x) { + static const __int128_t Q = (__int128_t)Q1 * Q2 * Q3 * Q4; + static const __int128_t Qm1 = (__int128_t)Q2 * Q3 * Q4; + static const __int128_t Qm2 = (__int128_t)Q1 * Q3 * Q4; + static const __int128_t Qm3 = (__int128_t)Q1 * Q2 * Q4; + static const __int128_t Qm4 = (__int128_t)Q1 * Q2 * Q3; + + const uint64_t* x_u64 = (uint64_t*)x; + for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) { + __int128_t tmp = 0; + tmp += (((x_u64[i + 0] % Q1) * Q1_CRT_CST) % Q1) * Qm1; + tmp += (((x_u64[i + 1] % Q2) * Q2_CRT_CST) % Q2) * Qm2; + tmp += (((x_u64[i + 2] % Q3) * Q3_CRT_CST) % Q3) * Qm3; + tmp += (((x_u64[i + 3] % Q4) * Q4_CRT_CST) % Q4) * Qm4; + tmp %= Q; + res[j] = (tmp >= (Q + 1) / 2) ? tmp - Q : tmp; + } +} diff --git a/spqlios/lib/spqlios/q120/q120_common.h b/spqlios/lib/spqlios/q120/q120_common.h new file mode 100644 index 0000000..9acef5e --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_common.h @@ -0,0 +1,94 @@ +#ifndef SPQLIOS_Q120_COMMON_H +#define SPQLIOS_Q120_COMMON_H + +#include + +#if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \ + !defined(SPQLIOS_Q120_USE_31_BIT_PRIMES) +#define SPQLIOS_Q120_USE_30_BIT_PRIMES +#endif + +/** + * 29-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES +#define Q1 ((1u << 29) - 2 * (1u << 17) + 1) +#define OMEGA1 78289835 +#define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 29) - 5 * (1u << 17) + 1) +#define OMEGA2 178519192 +#define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 29) - 26 * (1u << 17) + 1) +#define OMEGA3 483889678 +#define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 29) - 35 * (1u << 17) + 1) +#define OMEGA4 239808033 +#define Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +/** + * 30-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES +#define Q1 ((1u << 30) - 2 * (1u << 17) + 1) +#define OMEGA1 1070907127 +#define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 30) - 17 * (1u << 17) + 1) +#define OMEGA2 315046632 +#define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 30) - 23 * (1u << 17) + 1) +#define OMEGA3 309185662 +#define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 30) - 42 * (1u << 17) + 1) +#define OMEGA4 846468380 +#define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +/** + * 31-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES +#define Q1 ((1u << 31) - 1 * 
(1u << 17) + 1) +#define OMEGA1 1615402923 +#define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 31) - 4 * (1u << 17) + 1) +#define OMEGA2 1137738560 +#define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 31) - 11 * (1u << 17) + 1) +#define OMEGA3 154880552 +#define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 31) - 23 * (1u << 17) + 1) +#define OMEGA4 558784885 +#define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4}; +static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + +#define MAX_ELL 10000 + +// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), +// each between [0 and 2^32-1] +typedef struct _q120a q120a; + +// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), +// each between [0 and 2^64-1] +typedef struct _q120b q120b; + +// each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x +// mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1] +typedef struct _q120c q120c; + +typedef struct _q120x2b q120x2b; +typedef struct _q120x2c q120x2c; + +#endif // SPQLIOS_Q120_COMMON_H diff --git a/spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c b/spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c new file mode 100644 index 0000000..18db6cb --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c @@ -0,0 +1,5 @@ +#include "q120_ntt_private.h" + +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } + +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } diff --git a/spqlios/lib/spqlios/q120/q120_ntt.c b/spqlios/lib/spqlios/q120/q120_ntt.c new file mode 100644 index 0000000..f58b98e --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_ntt.c @@ -0,0 +1,340 @@ +#include +#include +#include +#include + +#include "q120_ntt_private.h" + +q120_ntt_precomp* new_precomp(const uint64_t n) { + q120_ntt_precomp* precomp = malloc(sizeof(*precomp)); + precomp->n = n; + + assert(n && !(n & (n - 1)) && n <= (1 << 16)); // n is a power of 2 smaller than 2^16 + const uint64_t logN = ceil(log2(n)); + precomp->level_metadata = malloc((logN + 2) * sizeof(*precomp->level_metadata)); + + precomp->powomega = spqlios_alloc_custom_align(32, 4 * 2 * n * sizeof(*(precomp->powomega))); + + return precomp; +} + +uint32_t modq_pow(const uint32_t x, const int64_t n, const uint32_t q) { + uint64_t np = (n % (q - 1) + q - 1) % (q - 1); + + uint64_t val_pow = x; + uint64_t res = 1; + while (np != 0) { + if (np & 1) res = (res * val_pow) % q; + val_pow = (val_pow * val_pow) % q; + np >>= 1; + } + return res; +} + +void fill_omegas(const uint64_t n, uint32_t omegas[4]) { + for (uint64_t k = 0; k < 4; ++k) { + omegas[k] = modq_pow(OMEGAS_VEC[k], (1 << 16) / n, PRIMES_VEC[k]); + } + +#ifndef NDEBUG + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + for (int k = 1; k < 4; ++k) { + if (logQ != ceil(log2(PRIMES_VEC[k]))) { + fprintf(stderr, "The 4 primes must have the same bit-size\n"); + exit(-1); + } + } + + // check if each omega is a 2.n primitive root of unity + for (uint64_t k = 0; k < 4; ++k) { + assert(modq_pow(omegas[k], 2 * n, PRIMES_VEC[k]) == 1); + for (uint64_t i = 1; i < 2 * n; ++i) { + assert(modq_pow(omegas[k], i, PRIMES_VEC[k]) != 1); + } + } + + 
if (logQ > 31) { + fprintf(stderr, "Modulus q bit-size is larger than 30 bit\n"); + exit(-1); + } +#endif +} + +uint64_t fill_reduction_meta(const uint64_t bs_start, q120_ntt_reduc_step_precomp* reduc_metadata) { + // fill reduction metadata + uint64_t bs_after_reduc = -1; + { + uint64_t min_h = -1; + + for (uint64_t h = bs_start / 2; h < bs_start; ++h) { + uint64_t t = 0; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t t1 = bs_start - h + (uint64_t)ceil(log2((UINT64_C(1) << h) % PRIMES_VEC[k])); + const uint64_t t2 = UINT64_C(1) + ((t1 > h) ? t1 : h); + if (t < t2) t = t2; + } + if (t < bs_after_reduc) { + min_h = h; + bs_after_reduc = t; + } + } + + reduc_metadata->h = min_h; + reduc_metadata->mask = (UINT64_C(1) << min_h) - 1; + for (uint64_t k = 0; k < 4; ++k) { + reduc_metadata->modulo_red_cst[k] = (UINT64_C(1) << min_h) % PRIMES_VEC[k]; + } + + assert(bs_after_reduc < 64); + } + + return bs_after_reduc; +} + +uint64_t round_up_half_n(const uint64_t n) { return (n + 1) / 2; } + +EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n) { + uint32_t omega_vec[4]; + fill_omegas(n, omega_vec); + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + + q120_ntt_precomp* precomp = new_precomp(n); + + uint64_t bs = precomp->input_bit_size = 64; + + LOG("NTT parameters:\n"); + LOG("\tsize = %" PRIu64 "\n", n) + LOG("\tlogQ = %" PRIu64 "\n", logQ); + LOG("\tinput bit-size = %" PRIu64 "\n", bs); + + if (n == 1) return precomp; + + // fill reduction metadata + uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata)); + + // forward metadata + q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata; + + // first level a_k.omega^k + { + const uint64_t half_bs = (bs + 1) / 2; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + level_metadata_ptr->bs = bs = half_bs + logQ + 1; + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 " (a_k.omega^k) \n", n, bs); + level_metadata_ptr++; + } + + for (uint64_t nn = n; nn >= 4; nn /= 2) { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + double bs_1 = bs + 1; // bit-size of term a+b or a-b + + const uint64_t half_bs = round_up_half_n(bs_1); + uint64_t bs_2 = half_bs + logQ + 1; // bit-size of term (a-b).omega^k + bs = (bs_1 > bs_2) ? 
bs_1 : bs_2; + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + level_metadata_ptr++; + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs); + } + + // last level (a-b, a+b) + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = ((uint64_t)PRIMES_VEC[k] << (bs - logQ)); + level_metadata_ptr->bs = ++bs; + level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs); + } + precomp->output_bit_size = bs; + + // omega powers + uint64_t* powomega = malloc(sizeof(*powomega) * 2 * n); + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t q = PRIMES_VEC[k]; + + for (uint64_t i = 0; i < 2 * n; ++i) { + powomega[i] = modq_pow(omega_vec[k], i, q); + } + + uint64_t* powomega_ptr = precomp->powomega + k; + level_metadata_ptr = precomp->level_metadata; + + { + // const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs; + for (uint64_t i = 0; i < n; ++i) { + uint64_t t = powomega[i]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * i] = (t1 << 32) + t; + } + powomega_ptr += 4 * n; + level_metadata_ptr++; + } + + for (uint64_t nn = n; nn >= 4; nn /= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + // const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs; + for (uint64_t i = 1; i < halfnn; ++i) { + uint64_t t = powomega[i * m]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * (i - 1)] = (t1 << 32) + t; + } + powomega_ptr += 4 * (halfnn - 1); + level_metadata_ptr++; + } + } + free(powomega); + + return precomp; +} + +EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n) { + uint32_t omega_vec[4]; + fill_omegas(n, omega_vec); + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + + q120_ntt_precomp* precomp = new_precomp(n); + + uint64_t bs = precomp->input_bit_size = 64; + + LOG("iNTT parameters:\n"); + LOG("\tsize = %" PRIu64 "\n", n) + LOG("\tlogQ = %" PRIu64 "\n", logQ); + LOG("\tinput bit-size = %" PRIu64 "\n", bs); + + if (n == 1) return precomp; + + // fill reduction metadata + uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata)); + + // backward metadata + q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata; + + // first level (a+b, a-b) adds 1-bit + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + level_metadata_ptr->bs = ++bs; + level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs); + level_metadata_ptr++; + } + + for (uint64_t nn = 4; nn <= n; nn *= 2) { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + const uint64_t half_bs = round_up_half_n(bs); + const uint64_t bs_mult = half_bs + logQ + 1; // bit-size of term b.omega^k + bs = 1 + ((bs > bs_mult) ? 
bs : bs_mult); // bit-size of a+b.omega^k or a-b.omega^k + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs_mult - logQ); + level_metadata_ptr++; + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs); + } + + // last level a_k.omega^k + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + const uint64_t half_bs = round_up_half_n(bs); + + bs = half_bs + logQ + 1; // bit-size of term a.omega^k + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", n, bs); + } + + // omega powers + uint32_t* powomegabar = malloc(sizeof(*powomegabar) * 2 * n); + for (int k = 0; k < 4; ++k) { + const uint64_t q = PRIMES_VEC[k]; + + for (uint64_t i = 0; i < 2 * n; ++i) { + powomegabar[i] = modq_pow(omega_vec[k], -i, q); + } + + uint64_t* powomega_ptr = precomp->powomega + k; + level_metadata_ptr = precomp->level_metadata + 1; + + for (uint64_t nn = 4; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t i = 1; i < halfnn; ++i) { + uint64_t t = powomegabar[i * m]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * (i - 1)] = (t1 << 32) + t; + } + powomega_ptr += 4 * (halfnn - 1); + level_metadata_ptr++; + } + + { + const uint64_t invNmod = modq_pow(n, -1, q); + for (uint64_t i = 0; i < n; ++i) { + uint64_t t = (powomegabar[i] * invNmod) % q; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * i] = (t1 << 32) + t; + } + } + } + + free(powomegabar); + + return precomp; +} + +void del_precomp(q120_ntt_precomp* precomp) { + spqlios_free(precomp->powomega); + free(precomp->level_metadata); + free(precomp); +} + +EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); } + +EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); } diff --git a/spqlios/lib/spqlios/q120/q120_ntt.h b/spqlios/lib/spqlios/q120/q120_ntt.h new file mode 100644 index 0000000..329b54d --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_ntt.h @@ -0,0 +1,25 @@ +#ifndef SPQLIOS_Q120_NTT_H +#define SPQLIOS_Q120_NTT_H + +#include "../commons.h" +#include "q120_common.h" + +typedef struct _q120_ntt_precomp q120_ntt_precomp; + +EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n); +EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp); + +EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n); +EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp); + +/** + * @brief computes a direct ntt in-place over data. + */ +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); + +/** + * @brief computes an inverse ntt in-place over data. 
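+ *
+ * Illustrative call sequence (a sketch using only the functions declared in this header):
+ *   q120_ntt_precomp* p = q120_new_intt_bb_precomp(n);  // n: a power of two, n <= 2^16
+ *   q120_intt_bb_avx2(p, data);                         // data: n contiguous q120b values, transformed in-place
+ *   q120_del_intt_bb_precomp(p);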
+ */ +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); + +#endif // SPQLIOS_Q120_NTT_H diff --git a/spqlios/lib/spqlios/q120/q120_ntt_avx2.c b/spqlios/lib/spqlios/q120/q120_ntt_avx2.c new file mode 100644 index 0000000..9d0b547 --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_ntt_avx2.c @@ -0,0 +1,479 @@ +#include +#include +#include +#include +#include +#include + +#include "q120_common.h" +#include "q120_ntt_private.h" + +// at which level to switch from computations by level to computations by block +#define CHANGE_MODE_N 1024 + +__always_inline __m256i split_precompmul_si256(__m256i inp, __m256i powomega, const uint64_t h, const __m256i mask) { + const __m256i inp_low = _mm256_and_si256(inp, mask); + const __m256i t1 = _mm256_mul_epu32(inp_low, powomega); + + const __m256i inp_high = _mm256_srli_epi64(inp, h); + const __m256i powomega_high = _mm256_srli_epi64(powomega, 32); + const __m256i t2 = _mm256_mul_epu32(inp_high, powomega_high); + + return _mm256_add_epi64(t1, t2); +} + +__always_inline __m256i modq_red(const __m256i x, const uint64_t h, const __m256i mask, const __m256i _2_pow_h_modq) { + __m256i xh = _mm256_srli_epi64(x, h); + __m256i xl = _mm256_and_si256(x, mask); + __m256i xh_1 = _mm256_mul_epu32(xh, _2_pow_h_modq); + return _mm256_add_epi64(xl, xh_1); +} + +void print_data(const uint64_t n, const uint64_t* const data, const uint64_t q) { + for (uint64_t i = 0; i < n; i++) { + printf("%" PRIu64 " ", *(data + i) % q); + } + printf("\n"); +} + +double max_bit_size(const void* const begin, const void* const end) { + double bs = 0; + const uint64_t* data = (uint64_t*)begin; + while (data != end) { + double t = log2(*(data++)); + if (bs < t) { + bs = t; + } + } + return bs; +} + +void ntt_iter_first(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData, + const __m256i* powomega) { + const uint64_t h = itData->half_bs; + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i x = _mm256_loadu_si256(data); + __m256i po = _mm256_loadu_si256(powomega); + __m256i r = split_precompmul_si256(x, po, h, vmask); + _mm256_storeu_si256(data, r); + + data++; + powomega++; + } +} + +void ntt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + const __m256i a = _mm256_loadu_si256(ptr1); + const __m256i b = _mm256_loadu_si256(ptr2); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + __m256i ap = _mm256_add_epi64(a, b); + + _mm256_storeu_si256(ptr1, ap); + + __m256i b1 = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + __m256i po = _mm256_loadu_si256(po_ptr); + + __m256i bp = split_precompmul_si256(b1, po, itData->half_bs, vmask); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void ntt_iter_red(const 
uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega, + const q120_ntt_reduc_step_precomp* const reduc_precomp) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + __m256i ap = _mm256_add_epi64(a, b); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + __m256i po = _mm256_loadu_si256(po_ptr); + bp = split_precompmul_si256(bp, po, itData->half_bs, vmask); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) { + // assert((size_t)data_ptr % 32 == 0); // alignment check + + const uint64_t n = precomp->n; + if (n == 1) return; + + const q120_ntt_step_precomp* itData = precomp->level_metadata; + const __m256i* powomega = (__m256i*)precomp->powomega; + + __m256i* const begin = (__m256i*)data_ptr; + const __m256i* const end = ((__m256i*)data_ptr) + n; + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size); + assert(bs <= precomp->input_bit_size); + } + + // first iteration a_k.omega^k + ntt_iter_first(begin, end, itData, powomega); + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += n; + itData++; + + const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N; + // const uint64_t split_nn = 2; + + // computations by level + for (uint64_t nn = n; nn > split_nn; nn /= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + ntt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter(nn, begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? 
'*' : ' '); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + + // computations by memory block + if (split_nn >= 2) { + const q120_ntt_step_precomp* itData1 = itData; + const __m256i* powomega1 = powomega; + for (__m256i* it = begin; it < end; it += split_nn) { + __m256i* const begin1 = it; + const __m256i* const end1 = it + split_nn; + + itData = itData1; + powomega = powomega1; + for (uint64_t nn = split_nn; nn >= 2; nn /= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + ntt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter(nn, begin1, end1, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1); + // LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + } + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", UINT64_C(1), bs, precomp->output_bit_size); + assert(bs < precomp->output_bit_size); + } +} + +void intt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + const __m256i a = _mm256_loadu_si256(ptr1); + const __m256i b = _mm256_loadu_si256(ptr2); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + __m256i po = _mm256_loadu_si256(po_ptr); + __m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask); + + __m256i ap = _mm256_add_epi64(a, bo); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void intt_iter_red(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega, + const q120_ntt_reduc_step_precomp* const reduc_precomp) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; 
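+      // remaining butterflies of this block: both inputs are first brought back
+      // below 2^h by modq_red, b is multiplied by the split power of omega held
+      // in po, and the pair (a + b*omega, a - b*omega) is stored back; vq2bs is
+      // a multiple of each prime, added so that the subtraction cannot underflow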
+ for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + __m256i po = _mm256_loadu_si256(po_ptr); + __m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask); + + __m256i ap = _mm256_add_epi64(a, bo); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void ntt_iter_first_red(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData, + const __m256i* powomega, const q120_ntt_reduc_step_precomp* const reduc_precomp) { + const uint64_t h = itData->half_bs; + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i x = _mm256_loadu_si256(data); + x = modq_red(x, reduc_precomp->h, reduc_mask, reduc_cst); + __m256i po = _mm256_loadu_si256(powomega); + __m256i r = split_precompmul_si256(x, po, h, vmask); + _mm256_storeu_si256(data, r); + + data++; + powomega++; + } +} + +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) { + // assert((size_t)data_ptr % 32 == 0); // alignment check + + const uint64_t n = precomp->n; + if (n == 1) return; + + const q120_ntt_step_precomp* itData = precomp->level_metadata; + const __m256i* powomega = (__m256i*)precomp->powomega; + + __m256i* const begin = (__m256i*)data_ptr; + const __m256i* const end = ((__m256i*)data_ptr) + n; + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size); + assert(bs <= precomp->input_bit_size); + } + + const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N; + + // computations by memory block + if (split_nn >= 2) { + const q120_ntt_step_precomp* itData1 = itData; + const __m256i* powomega1 = powomega; + for (__m256i* it = begin; it < end; it += split_nn) { + __m256i* const begin1 = it; + const __m256i* const end1 = it + split_nn; + + itData = itData1; + powomega = powomega1; + for (uint64_t nn = 2; nn <= split_nn; nn *= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + intt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata); + } else { + intt_iter(nn, begin1, end1, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1); + // LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + } + } + + // computations by level + // for (uint64_t nn = 2; nn <= n; nn *= 2) { + for (uint64_t nn = 2 * split_nn; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + intt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + intt_iter(nn, begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? 
'*' : ' '); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + + // last iteration a_k . omega^k . n^-1 + if (itData->reduce) { + ntt_iter_first_red(begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter_first(begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs); + assert(bs < itData->bs); + } +} diff --git a/spqlios/lib/spqlios/q120/q120_ntt_private.h b/spqlios/lib/spqlios/q120/q120_ntt_private.h new file mode 100644 index 0000000..c727ecd --- /dev/null +++ b/spqlios/lib/spqlios/q120/q120_ntt_private.h @@ -0,0 +1,39 @@ +#include "q120_ntt.h" + +#ifndef NDEBUG +#define CHECK_BOUNDS 1 +#define VERBOSE +#else +#define CHECK_BOUNDS 0 +#endif + +#ifndef VERBOSE +#define LOG(...) ; +#else +#define LOG(...) printf(__VA_ARGS__); +#endif + +typedef struct _q120_ntt_step_precomp { + uint64_t q2bs[4]; // q2bs = 2^{bs-31}.q[k] + uint64_t bs; // inputs at this iterations must be in Q_n + uint64_t half_bs; // == ceil(bs/2) + uint64_t mask; // (1< +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_from_znx64_ref(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) { + // naive version of the function (just cast) + const uint64_t nn = precomp->m << 1; + double* res = (double*)r; + for (uint64_t i = 0; i < nn; ++i) { + res[i] = (double)x[i]; + } +} + +void* init_reim_from_znx64_precomp(REIM_FROM_ZNX64_PRECOMP* const res, uint32_t m, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + // currently, we are going to use the trick add 3.2^51, mask the exponent, reinterpret bits as double. + // therefore we need the input values to be < 2^50. + if (log2bound > 50) return spqlios_error("Invalid log2bound error: must be in [0,50]"); + res->m = m; + FROM_ZNX64_FUNC resf = reim_from_znx64_ref; + if (m >= 8) { + if (CPU_SUPPORTS("avx2")) { + resf = reim_from_znx64_bnd50_fma; + } + } + res->function = resf; + return res; +} + +EXPORT REIM_FROM_ZNX64_PRECOMP* new_reim_from_znx64_precomp(uint32_t m, uint32_t log2bound) { + REIM_FROM_ZNX64_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_znx64_precomp(res, m, log2bound)); +} + +EXPORT void reim_from_znx64(const REIM_FROM_ZNX64_PRECOMP* tables, void* r, const int64_t* a) { + tables->function(tables, r, a); +} + +/** + * @brief Simpler API for the znx64 to reim conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
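 * A minimal usage sketch (sizes are illustrative, assuming m=1024, i.e. n=2048
 * int64 coefficients bounded by 2^19 in absolute value):
 *
 *   int64_t a[2048];    // coefficients mod X^2048+1
 *   double reim[2048];  // m=1024 complex values, re[0..m) then im[0..m)
 *   reim_from_znx64_simple(1024, 19, reim, a);
 *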
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment + */ +EXPORT void reim_from_znx64_simple(uint32_t m, uint32_t log2bound, void* r, const int64_t* a) { + // not checking for log2bound which is not relevant here + static REIM_FROM_ZNX64_PRECOMP precomp[32]; + REIM_FROM_ZNX64_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim_from_znx64_precomp(p, m, log2bound)) abort(); + } + p->function(p, r, a); +} + +void reim_from_znx32_ref(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } +void reim_from_znx32_avx2_fma(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } + +void* init_reim_from_znx32_precomp(REIM_FROM_ZNX32_PRECOMP* const res, uint32_t m, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (log2bound > 32) return spqlios_error("Invalid log2bound error: must be in [0,32]"); + res->m = m; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_from_znx32_avx2_fma; + } else { + res->function = reim_from_znx32_ref; + } + } else { + res->function = reim_from_znx32_ref; + } + return res; +} + +EXPORT REIM_FROM_ZNX32_PRECOMP* new_reim_from_znx32_precomp(uint32_t m, uint32_t log2bound) { + REIM_FROM_ZNX32_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_znx32_precomp(res, m, log2bound)); +} + +void reim_from_tnx32_ref(const REIM_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } +void reim_from_tnx32_avx2_fma(const REIM_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } + +void* init_reim_from_tnx32_precomp(REIM_FROM_TNX32_PRECOMP* const res, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + res->m = m; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_from_tnx32_avx2_fma; + } else { + res->function = reim_from_tnx32_ref; + } + } else { + res->function = reim_from_tnx32_ref; + } + return res; +} + +EXPORT REIM_FROM_TNX32_PRECOMP* new_reim_from_tnx32_precomp(uint32_t m) { + REIM_FROM_TNX32_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_tnx32_precomp(res, m)); +} + +void reim_to_tnx32_ref(const REIM_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { NOT_IMPLEMENTED(); } +void reim_to_tnx32_avx2_fma(const REIM_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { NOT_IMPLEMENTED(); } + +void* init_reim_to_tnx32_precomp(REIM_TO_TNX32_PRECOMP* const res, uint32_t m, double divisor, uint32_t log2overhead) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (log2overhead > 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (log2overhead <= 18) { + if (m >= 8) { + res->function = reim_to_tnx32_avx2_fma; + } else { + res->function = reim_to_tnx32_ref; + } + } else { + res->function = reim_to_tnx32_ref; + } + } else { + res->function = reim_to_tnx32_ref; + } + return res; +} + +EXPORT REIM_TO_TNX32_PRECOMP* new_reim_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + REIM_TO_TNX32_PRECOMP* res = 
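/* (The new_* constructors in this file follow a common idiom: allocate, then delegate
 * to init_*; spqlios_keep_or_free presumably frees the fresh allocation and returns
 * NULL when init_* reported an error via spqlios_error, and returns the object
 * otherwise. The *_simple wrappers additionally cache one precomputation per
 * power-of-two dimension, which is why the header advises a dry-run call before
 * multithreaded use.) */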
malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_tnx32_precomp(res, m, divisor, log2overhead)); +} + +EXPORT void reim_from_znx32_simple(uint32_t m, uint32_t log2bound, void* r, const int32_t* x) { + static REIM_FROM_ZNX32_PRECOMP* p[31] = {0}; + REIM_FROM_ZNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_from_znx32_precomp(m, log2bound); + (*f)->function(*f, r, x); +} + +EXPORT void reim_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) { + static REIM_FROM_TNX32_PRECOMP* p[31] = {0}; + REIM_FROM_TNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_from_tnx32_precomp(m); + (*f)->function(*f, r, x); +} + +EXPORT void reim_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) { + static REIM_TO_TNX32_PRECOMP* p[31] = {0}; + REIM_TO_TNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_to_tnx32_precomp(m, divisor, log2overhead); + (*f)->function(*f, r, x); +} + +void reim_to_znx64_ref(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + // for now, we stick to a slow implem + uint64_t nn = precomp->m << 1; + const double* v = (double*)x; + double invdiv = 1. / precomp->divisor; + for (uint64_t i = 0; i < nn; ++i) { + r[i] = (int64_t)rint(v[i] * invdiv); + } +} + +void* init_reim_to_znx64_precomp(REIM_TO_ZNX64_PRECOMP* const res, uint32_t m, double divisor, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (log2bound > 64) return spqlios_error("log2bound is too large"); + res->m = m; + res->divisor = divisor; + TO_ZNX64_FUNC resf = reim_to_znx64_ref; + if (CPU_SUPPORTS("avx2") && m >= 8) { + if (log2bound <= 50) { + resf = reim_to_znx64_avx2_bnd50_fma; + } else { + resf = reim_to_znx64_avx2_bnd63_fma; + } + } + res->function = resf; // must be the last one set + return res; +} + +EXPORT REIM_TO_ZNX64_PRECOMP* new_reim_to_znx64_precomp(uint32_t m, double divisor, uint32_t log2bound) { + REIM_TO_ZNX64_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_znx64_precomp(res, m, divisor, log2bound)); +} + +EXPORT void reim_to_znx64(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* a) { + precomp->function(precomp, r, a); +} + +/** + * @brief Simpler API for the znx64 to reim conversion. 
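 * Note that this wrapper goes in the reim-to-znx64 direction: each output
 * coefficient is x[i] / divisor rounded to the nearest integer (see
 * reim_to_znx64_ref above), and the cached precomputation is thread-local.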
+ */ +EXPORT void reim_to_znx64_simple(uint32_t m, double divisor, uint32_t log2bound, int64_t* r, const void* a) { + // not checking distinguishing <=50 or not + static __thread REIM_TO_ZNX64_PRECOMP p; + static __thread uint32_t prev_log2bound; + if (!p.function || p.m != m || p.divisor != divisor || prev_log2bound != log2bound) { + if (!init_reim_to_znx64_precomp(&p, m, divisor, log2bound)) abort(); + prev_log2bound = log2bound; + } + p.function(&p, r, a); +} diff --git a/spqlios/lib/spqlios/reim/reim_conversions_avx.c b/spqlios/lib/spqlios/reim/reim_conversions_avx.c new file mode 100644 index 0000000..30f98c9 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_conversions_avx.c @@ -0,0 +1,106 @@ +#include + +#include "reim_fft_private.h" + +void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) { + static const double EXPO = INT64_C(1) << 52; + const int64_t ADD_CST = INT64_C(1) << 51; + const double SUB_CST = INT64_C(3) << 51; + + const __m256d SUB_CST_4 = _mm256_set1_pd(SUB_CST); + const __m256i ADD_CST_4 = _mm256_set1_epi64x(ADD_CST); + const __m256d EXPO_4 = _mm256_set1_pd(EXPO); + + double(*out)[4] = (double(*)[4])r; + __m256i* in = (__m256i*)x; + __m256i* inend = (__m256i*)(x + (precomp->m << 1)); + do { + // read the next value + __m256i a = _mm256_loadu_si256(in); + a = _mm256_add_epi64(a, ADD_CST_4); + __m256d ad = _mm256_castsi256_pd(a); + ad = _mm256_or_pd(ad, EXPO_4); + ad = _mm256_sub_pd(ad, SUB_CST_4); + // store the next value + _mm256_storeu_pd(out[0], ad); + ++out; + ++in; + } while (in < inend); +} + +// version where the output norm can be as big as 2^63 +void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + static const uint64_t SIGN_MASK = 0x8000000000000000UL; + static const uint64_t EXPO_MASK = 0x7FF0000000000000UL; + static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; + static const uint64_t MANTISSA_MSB = 0x0010000000000000UL; + const double divisor_bits = precomp->divisor * ((double)(INT64_C(1) << 52)); + const double offset = precomp->divisor / 2.; + + const __m256d SIGN_MASK_4 = _mm256_castsi256_pd(_mm256_set1_epi64x(SIGN_MASK)); + const __m256i EXPO_MASK_4 = _mm256_set1_epi64x(EXPO_MASK); + const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); + const __m256i MANTISSA_MSB_4 = _mm256_set1_epi64x(MANTISSA_MSB); + const __m256d offset_4 = _mm256_set1_pd(offset); + const __m256i divi_bits_4 = _mm256_castpd_si256(_mm256_set1_pd(divisor_bits)); + + double(*in)[4] = (double(*)[4])x; + __m256i* out = (__m256i*)r; + __m256i* outend = (__m256i*)(r + (precomp->m << 1)); + do { + // read the next value + __m256d a = _mm256_loadu_pd(in[0]); + // a += sign(a) * m/2 + __m256d asign = _mm256_and_pd(a, SIGN_MASK_4); + a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_4)); + // sign: either 0 or -1 + __m256i sign_mask = _mm256_castpd_si256(asign); + sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63)); + // compute the exponents + __m256i a0exp = _mm256_and_si256(_mm256_castpd_si256(a), EXPO_MASK_4); + __m256i a0lsh = _mm256_sub_epi64(a0exp, divi_bits_4); + __m256i a0rsh = _mm256_sub_epi64(divi_bits_4, a0exp); + a0lsh = _mm256_srli_epi64(a0lsh, 52); + a0rsh = _mm256_srli_epi64(a0rsh, 52); + // compute the new mantissa + __m256i a0pos = _mm256_and_si256(_mm256_castpd_si256(a), MANTISSA_MASK_4); + a0pos = _mm256_or_si256(a0pos, MANTISSA_MSB_4); + a0lsh = _mm256_sllv_epi64(a0pos, a0lsh); + a0rsh = _mm256_srlv_epi64(a0pos, a0rsh); + 
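      /* a0lsh and a0rsh now hold the 53-bit significand shifted by the (positive or
       * negative) exponent difference with divisor*2^52; at most one of them is
       * non-zero, since the other shift count is >= 64. Their OR, computed next, is
       * |a|/divisor truncated, and together with the +/- divisor/2 offset added above
       * this yields (approximately) llround(x[i] / divisor) per lane, without the
       * 53-bit limit of a plain double-to-int64 cast. */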
__m256i final = _mm256_or_si256(a0lsh, a0rsh); + // negate if the sign was negative + final = _mm256_xor_si256(final, sign_mask); + final = _mm256_sub_epi64(final, sign_mask); + // read the next value + _mm256_storeu_si256(out, final); + ++out; + ++in; + } while (out < outend); +} + +// version where the output norm can be as big as 2^50 +void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; + const int64_t SUB_CST = INT64_C(1) << 51; + const double add_cst = precomp->divisor * ((double)(INT64_C(3) << 51)); + + const __m256i SUB_CST_4 = _mm256_set1_epi64x(SUB_CST); + const __m256d add_cst_4 = _mm256_set1_pd(add_cst); + const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); + + double(*in)[4] = (double(*)[4])x; + __m256i* out = (__m256i*)r; + __m256i* outend = (__m256i*)(r + (precomp->m << 1)); + do { + // read the next value + __m256d a = _mm256_loadu_pd(in[0]); + a = _mm256_add_pd(a, add_cst_4); + __m256i ai = _mm256_castpd_si256(a); + ai = _mm256_and_si256(ai, MANTISSA_MASK_4); + ai = _mm256_sub_epi64(ai, SUB_CST_4); + // store the next value + _mm256_storeu_si256(out, ai); + ++out; + ++in; + } while (out < outend); +} diff --git a/spqlios/lib/spqlios/reim/reim_execute.c b/spqlios/lib/spqlios/reim/reim_execute.c new file mode 100644 index 0000000..5a23926 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_execute.c @@ -0,0 +1,22 @@ +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} diff --git a/spqlios/lib/spqlios/reim/reim_fallbacks_aarch64.c b/spqlios/lib/spqlios/reim/reim_fallbacks_aarch64.c new file mode 100644 index 0000000..6488770 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fallbacks_aarch64.c @@ -0,0 +1,15 @@ +#include "reim_fft_private.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + UNDEFINED(); +} +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + UNDEFINED(); +} + +EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data) { UNDEFINED(); } +EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* tables, double* data) { UNDEFINED(); } + +//EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } +//EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } diff --git a/spqlios/lib/spqlios/reim/reim_fft.h b/spqlios/lib/spqlios/reim/reim_fft.h new file mode 100644 index 0000000..e8da04a --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft.h @@ -0,0 +1,207 @@ +#ifndef SPQLIOS_REIM_FFT_H +#define SPQLIOS_REIM_FFT_H + +#include "../commons.h" + +typedef struct 
reim_fft_precomp REIM_FFT_PRECOMP;
+typedef struct reim_ifft_precomp REIM_IFFT_PRECOMP;
+typedef struct reim_mul_precomp REIM_FFTVEC_MUL_PRECOMP;
+typedef struct reim_addmul_precomp REIM_FFTVEC_ADDMUL_PRECOMP;
+typedef struct reim_from_znx32_precomp REIM_FROM_ZNX32_PRECOMP;
+typedef struct reim_from_znx64_precomp REIM_FROM_ZNX64_PRECOMP;
+typedef struct reim_from_tnx32_precomp REIM_FROM_TNX32_PRECOMP;
+typedef struct reim_to_tnx32_precomp REIM_TO_TNX32_PRECOMP;
+typedef struct reim_to_tnx_precomp REIM_TO_TNX_PRECOMP;
+typedef struct reim_to_znx64_precomp REIM_TO_ZNX64_PRECOMP;
+
+/**
+ * @brief precomputes fft tables.
+ * The FFT tables contain a constant section that is required for efficient FFT operations in dimension m.
+ * The resulting pointer is to be passed as "tables" argument to any call to the fft function.
+ * The user can optionally allocate zero or more computation buffers, which are scratch spaces that are contiguous to
+ * the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure
+ * that each of those buffers is never used simultaneously by two ffts on different threads at the same time. The fft
+ * table must be deleted by delete_reim_fft_precomp after its last usage.
+ */
+EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp(uint32_t m, uint32_t num_buffers);
+
+/**
+ * @brief gets the address of an fft buffer allocated during new_reim_fft_precomp.
+ * This buffer can be used as data pointer in subsequent calls to fft,
+ * and does not need to be released afterwards.
+ */
+EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index);
+
+/**
+ * @brief allocates a new fft buffer.
+ * This buffer can be used as data pointer in subsequent calls to fft,
+ * and must be deleted afterwards by calling delete_reim_fft_buffer.
+ */
+EXPORT double* new_reim_fft_buffer(uint32_t m);
+
+/**
+ * @brief deletes an fft buffer previously allocated by new_reim_fft_buffer.
+ */
+EXPORT void delete_reim_fft_buffer(double* buffer);
+
+/**
+ * @brief deallocates an fft table and all its built-in buffers.
+ */
+#define delete_reim_fft_precomp free
+
+/**
+ * @brief computes a direct fft in-place over data.
+ */
+EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data);
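/*
 * Usage sketch for the precomp/buffer API above (values are illustrative and error
 * handling is omitted; the data layout is the one used by the reim_fft
 * implementations: 2*m doubles, re[0..m) followed by im[0..m)):
 *
 *   REIM_FFT_PRECOMP* tables = new_reim_fft_precomp(1024, 1);  // m = 1024, one built-in buffer
 *   double* buf = reim_fft_precomp_get_buffer(tables, 0);
 *   // ... fill buf[0..2048) with the re/im parts of the coefficients ...
 *   reim_fft(tables, buf);                                     // in-place forward FFT
 *   delete_reim_fft_precomp(tables);
 */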
+
+EXPORT REIM_IFFT_PRECOMP* new_reim_ifft_precomp(uint32_t m, uint32_t num_buffers);
+EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index);
+EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data);
+#define delete_reim_ifft_precomp free
+
+EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m);
+EXPORT void reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b);
+#define delete_reim_fftvec_mul_precomp free
+
+EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m);
+EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b);
+#define delete_reim_fftvec_addmul_precomp free
+
+/**
+ * @brief prepares a conversion from ZnX to the cplx layout.
+ * All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use
+ * this function on a larger coefficient is undefined behaviour. The resulting precomputed data must
+ * be freed with `delete_reim_from_znx32_precomp`
+ * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
+ * int32 coefficients in natural order modulo X^n+1
+ * @param log2bound bound on the input coefficients. Must be between 0 and 32
+ */
+EXPORT REIM_FROM_ZNX32_PRECOMP* new_reim_from_znx32_precomp(uint32_t m, uint32_t log2bound);
+
+/**
+ * @brief converts from ZnX (int32 coefficients) to the cplx layout.
+ * @param tables precomputed data obtained by new_reim_from_znx32_precomp.
+ * @param r resulting array of m complex coefficients mod X^m-i
+ * @param a input array of n bounded integer coefficients mod X^n+1
+ */
+EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a);
+/** @brief frees precomputed conversion data initialized with new_reim_from_znx32_precomp. */
+#define delete_reim_from_znx32_precomp free
+
+/**
+ * @brief converts from ZnX (int64 coefficients) to the cplx layout.
+ * @param tables precomputed data obtained by new_reim_from_znx64_precomp.
+ * @param r resulting array of m complex coefficients mod X^m-i
+ * @param a input array of n bounded integer coefficients mod X^n+1
+ */
+EXPORT void reim_from_znx64(const REIM_FROM_ZNX64_PRECOMP* tables, void* r, const int64_t* a);
+/** @brief prepares a conversion from ZnX (int64 coefficients) to the cplx layout; free the result with delete_reim_from_znx64_precomp. */
+EXPORT REIM_FROM_ZNX64_PRECOMP* new_reim_from_znx64_precomp(uint32_t m, uint32_t maxbnd);
+#define delete_reim_from_znx64_precomp free
+EXPORT void reim_from_znx64_simple(uint32_t m, uint32_t log2bound, void* r, const int64_t* a);
+
+/**
+ * @brief prepares a conversion from TnX to the cplx layout.
+ * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
+ * torus32 coefficients. The resulting precomputed data must
+ * be freed with `delete_reim_from_tnx32_precomp`
+ */
+EXPORT REIM_FROM_TNX32_PRECOMP* new_reim_from_tnx32_precomp(uint32_t m);
+/**
+ * @brief converts from TnX to the cplx layout.
+ * @param tables precomputed data obtained by new_reim_from_tnx32_precomp.
+ * @param r resulting array of m complex coefficients mod X^m-i
+ * @param a input array of n torus32 coefficients mod X^n+1
+ */
+EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a);
+/** @brief frees precomputed conversion data initialized with new_reim_from_tnx32_precomp. */
+#define delete_reim_from_tnx32_precomp free
+
+/**
+ * @brief prepares a rescale and conversion from the cplx layout to TnX.
+ * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
+ * torus32 coefficients.
+ * @param divisor must be a power of two. The inputs are divided by divisor before being reduced modulo 1.
+ * Remember that the output of an iFFT must be divided by m.
+ * @param log2overhead all input absolute values must be within divisor.2^log2overhead.
+ * For any inputs outside of these bounds, the conversion is undefined behaviour.
+ * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18.
+ */
+EXPORT REIM_TO_TNX32_PRECOMP* new_reim_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead);
+
+/**
+ * @brief rescales, converts and reduces mod 1 from the cplx layout to torus32.
+ * @param tables precomputed data obtained by new_reim_to_tnx32_precomp.
+ * @param r resulting array of n torus32 coefficients mod X^n+1
+ * @param a input array of m cplx coefficients mod X^m-i
+ */
+EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a);
+#define delete_reim_to_tnx32_precomp free
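/*
 * A minimal end-to-end sketch using the "simple" wrappers declared further down in
 * this header (m = 1024 and divisor = m are illustrative; divisor = m follows the
 * remark above that an iFFT output must be divided by m, and the exact torus scaling
 * convention is an assumption). Note that in this patch the znx32/tnx32 reference
 * kernels are still NOT_IMPLEMENTED() stubs, so this only illustrates the intended API:
 *
 *   int32_t t[2048];                      // n = 2m torus32 coefficients
 *   double buf[2048];                     // m complex values: re[0..m) then im[0..m)
 *   reim_from_tnx32_simple(1024, buf, t);
 *   reim_fft_simple(1024, buf);
 *   reim_ifft_simple(1024, buf);
 *   reim_to_tnx32_simple(1024, 1024., 18, t, buf);   // divisor = m, log2overhead = 18
 */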
+
+/**
+ * @brief prepares a rescale and conversion from the cplx layout to TnX (doubles).
+ * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
+ * torus coefficients (doubles).
+ * @param divisor must be a power of two. The inputs are divided by divisor before being reduced modulo 1.
+ * Remember that the output of an iFFT must be divided by m.
+ * @param log2overhead all input absolute values must be within divisor.2^log2overhead.
+ * For any inputs outside of these bounds, the conversion is undefined behaviour.
+ * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18.
+ */
+EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead);
+/**
+ * @brief rescales, converts and reduces mod 1 from the cplx layout to the torus (doubles).
+ * @param tables precomputed data obtained by new_reim_to_tnx_precomp.
+ * @param r resulting array of n torus coefficients (doubles) mod X^n+1
+ * @param a input array of m cplx coefficients mod X^m-i
+ */
+EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* a);
+#define delete_reim_to_tnx_precomp free
+EXPORT void reim_to_tnx_simple(uint32_t m, double divisor, uint32_t log2overhead, double* r, const double* a);
+
+EXPORT REIM_TO_ZNX64_PRECOMP* new_reim_to_znx64_precomp(uint32_t m, double divisor, uint32_t log2bound);
+#define delete_reim_to_znx64_precomp free
+EXPORT void reim_to_znx64(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* a);
+EXPORT void reim_to_znx64_simple(uint32_t m, double divisor, uint32_t log2bound, int64_t* r, const void* a);
+
+/**
+ * @brief Simpler API for the fft function.
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
+EXPORT void reim_fft_simple(uint32_t m, void* data);
+/**
+ * @brief Simpler API for the ifft function.
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
+ * It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread
+ * environment */
+EXPORT void reim_ifft_simple(uint32_t m, void* data);
+/**
+ * @brief Simpler API for the fftvec multiplication function.
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
+EXPORT void reim_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b);
+/**
+ * @brief Simpler API for the fftvec addmul function.
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
+EXPORT void reim_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b);
+/**
+ * @brief Simpler API for the znx32 to cplx conversion.
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_from_znx32_simple(uint32_t m, uint32_t log2bound, void* r, const int32_t* x); +/** + * @brief Simpler API for the tnx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_from_tnx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x); + +#endif // SPQLIOS_REIM_FFT_H diff --git a/spqlios/lib/spqlios/reim/reim_fft16_avx_fma.s b/spqlios/lib/spqlios/reim/reim_fft16_avx_fma.s new file mode 100644 index 0000000..e68012c --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft16_avx_fma.s @@ -0,0 +1,167 @@ +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +.globl reim_fft16_avx_fma +reim_fft16_avx_fma: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +vmovupd (%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai +vmulpd %ymm7,%ymm13,%ymm9 # ia4.omai +vmulpd %ymm2,%ymm13,%ymm10 # ra0.omai +vmulpd %ymm3,%ymm13,%ymm11 # ra4.omai +vfmsub231pd %ymm2,%ymm12,%ymm8 # rprod0 +vfmsub231pd %ymm3,%ymm12,%ymm9 # rprod4 +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 +vfmadd231pd %ymm7,%ymm12,%ymm11 # iprod4 +vsubpd %ymm8,%ymm0,%ymm2 +vsubpd %ymm9,%ymm1,%ymm3 +vsubpd %ymm10,%ymm4,%ymm6 +vsubpd %ymm11,%ymm5,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm4,%ymm4 +vaddpd %ymm11,%ymm5,%ymm5 + +1: +vmovupd 16(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm5,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm1,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm3,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm1,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm3,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm5,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm1 +vaddpd %ymm9,%ymm2,%ymm3 +vsubpd %ymm10,%ymm4,%ymm5 +vaddpd %ymm11,%ymm6,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm2,%ymm2 +vaddpd %ymm10,%ymm4,%ymm4 +vsubpd %ymm11,%ymm6,%ymm6 + +2: +vmovupd 0x20(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # 
ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vmulpd %ymm10,%ymm13,%ymm4 # ia0.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # ia4.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # ra0.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # ra4.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) +vsubpd %ymm4,%ymm0,%ymm8 +vaddpd %ymm5,%ymm1,%ymm9 +vsubpd %ymm6,%ymm2,%ymm10 +vaddpd %ymm7,%ymm3,%ymm11 +vaddpd %ymm4,%ymm0,%ymm0 +vsubpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vsubpd %ymm7,%ymm3,%ymm3 + +#vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +#vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +#vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +#vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +#vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +#vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +#vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +#vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x40(%rdx),%ymm12 +vmovupd 0x60(%rdx),%ymm13 + +#vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +#vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +#vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +#vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +#vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +#vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +#vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +#vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm4 +vaddpd %ymm9,%ymm1,%ymm5 +vsubpd %ymm10,%ymm2,%ymm6 +vaddpd %ymm11,%ymm3,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vsubpd %ymm11,%ymm3,%ymm3 + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +4: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret +.size reim_fft16_avx_fma, .-reim_fft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git 
a/spqlios/lib/spqlios/reim/reim_fft16_avx_fma_win32.s b/spqlios/lib/spqlios/reim/reim_fft16_avx_fma_win32.s new file mode 100644 index 0000000..add742e --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft16_avx_fma_win32.s @@ -0,0 +1,203 @@ + .text + .p2align 4 + .globl reim_fft16_avx_fma + .def reim_fft16_avx_fma; .scl 2; .type 32; .endef +reim_fft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + movq %r8,%rdx + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq reim_fft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +#.globl reim_fft16_avx_fma_amd64 +reim_fft16_avx_fma_amd64: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +vmovupd (%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai +vmulpd %ymm7,%ymm13,%ymm9 # ia4.omai +vmulpd %ymm2,%ymm13,%ymm10 # ra0.omai +vmulpd %ymm3,%ymm13,%ymm11 # ra4.omai +vfmsub231pd %ymm2,%ymm12,%ymm8 # rprod0 +vfmsub231pd %ymm3,%ymm12,%ymm9 # rprod4 +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 +vfmadd231pd %ymm7,%ymm12,%ymm11 # iprod4 +vsubpd %ymm8,%ymm0,%ymm2 +vsubpd %ymm9,%ymm1,%ymm3 +vsubpd %ymm10,%ymm4,%ymm6 +vsubpd %ymm11,%ymm5,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm4,%ymm4 +vaddpd %ymm11,%ymm5,%ymm5 + +1: +vmovupd 16(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm5,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm1,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm3,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm1,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm3,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm5,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm1 +vaddpd %ymm9,%ymm2,%ymm3 +vsubpd %ymm10,%ymm4,%ymm5 +vaddpd %ymm11,%ymm6,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm2,%ymm2 +vaddpd %ymm10,%ymm4,%ymm4 +vsubpd %ymm11,%ymm6,%ymm6 + +2: +vmovupd 0x20(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # 
ymm3 contains im to add (itw) + +vmulpd %ymm10,%ymm13,%ymm4 # ia0.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # ia4.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # ra0.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # ra4.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) +vsubpd %ymm4,%ymm0,%ymm8 +vaddpd %ymm5,%ymm1,%ymm9 +vsubpd %ymm6,%ymm2,%ymm10 +vaddpd %ymm7,%ymm3,%ymm11 +vaddpd %ymm4,%ymm0,%ymm0 +vsubpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vsubpd %ymm7,%ymm3,%ymm3 + +#vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +#vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +#vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +#vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +#vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +#vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +#vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +#vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x40(%rdx),%ymm12 +vmovupd 0x60(%rdx),%ymm13 + +#vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +#vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +#vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +#vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +#vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +#vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +#vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +#vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm4 +vaddpd %ymm9,%ymm1,%ymm5 +vsubpd %ymm10,%ymm2,%ymm6 +vaddpd %ymm11,%ymm3,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vsubpd %ymm11,%ymm3,%ymm3 + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +4: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret diff --git a/spqlios/lib/spqlios/reim/reim_fft4_avx_fma.c b/spqlios/lib/spqlios/reim/reim_fft4_avx_fma.c new file mode 100644 index 0000000..f7d24f3 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft4_avx_fma.c @@ -0,0 +1,66 @@ +#include +#include +#include + +#include 
"reim_fft_private.h" + +__always_inline void reim_ctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, + const __m128d omim) { + // rb * omre - ib * omim; + __m128d rprod0 = _mm_mul_pd(*ib, omim); + rprod0 = _mm_fmsub_pd(*rb, omre, rprod0); + + // rb * omim + ib * omre; + __m128d iprod0 = _mm_mul_pd(*rb, omim); + iprod0 = _mm_fmadd_pd(*ib, omre, iprod0); + + *rb = _mm_sub_pd(*ra, rprod0); + *ib = _mm_sub_pd(*ia, iprod0); + *ra = _mm_add_pd(*ra, rprod0); + *ia = _mm_add_pd(*ia, iprod0); +} + +EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m128d ra0 = _mm_loadu_pd(dre); + __m128d ra2 = _mm_loadu_pd(dre + 2); + __m128d ia0 = _mm_loadu_pd(dim); + __m128d ia2 = _mm_loadu_pd(dim + 2); + + // 1 + { + // duplicate omegas in precomp? + __m128d om = _mm_loadu_pd(omp); + __m128d omre = _mm_permute_pd(om, 0); + __m128d omim = _mm_permute_pd(om, 3); + + reim_ctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); + } + + // 2 + { + const __m128d fft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d om = _mm_loadu_pd(omp + 2); // om: r,i + __m128d omim = _mm_permute_pd(om, 1); // omim: i,r + __m128d omre = _mm_xor_pd(om, fft4neg); // omre: r,-i + + __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) + __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) + __m128d ra = _mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) + __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); + + ra0 = _mm_unpacklo_pd(ra, rb); + ia0 = _mm_unpacklo_pd(ia, ib); + ra2 = _mm_unpackhi_pd(ra, rb); + ia2 = _mm_unpackhi_pd(ia, ib); + } + + // 4 + _mm_storeu_pd(dre, ra0); + _mm_storeu_pd(dre + 2, ra2); + _mm_storeu_pd(dim, ia0); + _mm_storeu_pd(dim + 2, ia2); +} diff --git a/spqlios/lib/spqlios/reim/reim_fft8_avx_fma.c b/spqlios/lib/spqlios/reim/reim_fft8_avx_fma.c new file mode 100644 index 0000000..a39597f --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft8_avx_fma.c @@ -0,0 +1,89 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_ctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, + const __m256d omim) { + // rb * omre - ib * omim; + __m256d rprod0 = _mm256_mul_pd(*ib, omim); + rprod0 = _mm256_fmsub_pd(*rb, omre, rprod0); + + // rb * omim + ib * omre; + __m256d iprod0 = _mm256_mul_pd(*rb, omim); + iprod0 = _mm256_fmadd_pd(*ib, omre, iprod0); + + *rb = _mm256_sub_pd(*ra, rprod0); + *ib = _mm256_sub_pd(*ia, iprod0); + *ra = _mm256_add_pd(*ra, rprod0); + *ia = _mm256_add_pd(*ia, iprod0); +} + +EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m256d ra0 = _mm256_loadu_pd(dre); + __m256d ra4 = _mm256_loadu_pd(dre + 4); + __m256d ia0 = _mm256_loadu_pd(dim); + __m256d ia4 = _mm256_loadu_pd(dim + 4); + + // 1 + { + // duplicate omegas in precomp? 
+ __m128d omri = _mm_loadu_pd(omp); + __m256d omriri = _mm256_set_m128d(omri, omri); + __m256d omi = _mm256_permute_pd(omriri, 15); + __m256d omr = _mm256_permute_pd(omriri, 0); + + reim_ctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi); + } + + // 2 + { + const __m128d fft8neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d omri = _mm_loadu_pd(omp + 2); // r,i + __m128d omrmi = _mm_xor_pd(omri, fft8neg); // r,-i + __m256d omrirmi = _mm256_set_m128d(omrmi, omri); // r,i,r,-i + __m256d omi = _mm256_permute_pd(omrirmi, 3); // i,i,r,r + __m256d omr = _mm256_permute_pd(omrirmi, 12); // r,r,-i,-i + + __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); + __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); + __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); + __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); + ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); + ia0 = _mm256_permute2f128_pd(ia, ib, 0x20); + ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); + } + + // 3 + { + const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); + __m256d om = _mm256_loadu_pd(omp + 4); // r0,r1,i0,i1 + __m256d omi = _mm256_permute2f128_pd(om, om, 1); // i0,i1,r0,r1 + __m256d omr = _mm256_xor_pd(om, fft8neg2); // r0,r1,-i0,-i1 + + __m256d rb = _mm256_unpackhi_pd(ra0, ra4); + __m256d ib = _mm256_unpackhi_pd(ia0, ia4); + __m256d ra = _mm256_unpacklo_pd(ra0, ra4); + __m256d ia = _mm256_unpacklo_pd(ia0, ia4); + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra4 = _mm256_unpackhi_pd(ra, rb); + ia4 = _mm256_unpackhi_pd(ia, ib); + ra0 = _mm256_unpacklo_pd(ra, rb); + ia0 = _mm256_unpacklo_pd(ia, ib); + } + + // 4 + _mm256_storeu_pd(dre, ra0); + _mm256_storeu_pd(dre + 4, ra4); + _mm256_storeu_pd(dim, ia0); + _mm256_storeu_pd(dim + 4, ia4); +} diff --git a/spqlios/lib/spqlios/reim/reim_fft_avx2.c b/spqlios/lib/spqlios/reim/reim_fft_avx2.c new file mode 100644 index 0000000..351e6a6 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_avx2.c @@ -0,0 +1,162 @@ +#include "immintrin.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +__always_inline void reim_twiddle_fft_avx2_fma(uint32_t h, double* re, double* im, double om[2]) { + const __m128d omx = _mm_load_pd(om); + const __m256d omra = _mm256_set_m128d(omx, omx); + const __m256d omi = _mm256_unpackhi_pd(omra, omra); + const __m256d omr = _mm256_unpacklo_pd(omra, omra); + double* r0 = re; + double* r1 = re + h; + double* i0 = im; + double* i1 = im + h; + for (uint32_t i=0; i> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_avx2_fma(h, re, im, *omg); + *omg += 2; + mm = h; + } + while (mm > 16) { + uint32_t h = mm >> 2; + for (uint32_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_avx2_fma(h, re + off, im + off, *omg); + *omg += 4; + } + mm = h; + } + if (mm!=16) abort(); // bug! 
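  /* At this point mm == 16 by construction: an optional radix-2 pass (when log2(m)
   * is odd) followed by radix-4 passes divides the block size by 4 each time, so
   * the remaining blocks are handled by the 16-point leaf kernel below. */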
+ for (uint32_t off = 0; off < m; off += 16) { + reim_fft16_avx_fma(re+off, im+off, *omg); + *omg += 16; + } +} + +void reim_fft_rec_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_avx2_fma(m, re, im, omg); + const uint32_t h = m / 2; + reim_twiddle_fft_avx2_fma(h, re, im, *omg); + *omg += 2; + reim_fft_rec_16_avx2_fma(h, re, im, omg); + reim_fft_rec_16_avx2_fma(h, re + h, im + h, omg); +} + +void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat+m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_ref(re, im, omg); + case 4: + return reim_fft4_avx_fma(re, im, omg); + case 8: + return reim_fft8_avx_fma(re, im, omg); + case 16: + return reim_fft16_avx_fma(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_avx2_fma(m, re, im, &omg); + return reim_fft_rec_16_avx2_fma(m, re, im, &omg); +} diff --git a/spqlios/lib/spqlios/reim/reim_fft_core_template.h b/spqlios/lib/spqlios/reim/reim_fft_core_template.h new file mode 100644 index 0000000..b82a05b --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_core_template.h @@ -0,0 +1,162 @@ +#ifndef SPQLIOS_REIM_FFT_CORE_TEMPLATE_H +#define SPQLIOS_REIM_FFT_CORE_TEMPLATE_H + +// this file contains the main template for the fft strategy. +// it is meant to be included once for each specialization (ref, avx, neon) +// all the leaf functions it uses shall be defined in the following macros, +// before including this header. + +#if !defined(reim_fft16_f) || !defined(reim_fft16_pom_offset) || !defined(fill_reim_fft16_omegas_f) +#error "missing reim16 definitions" +#endif +#if !defined(reim_fft8_f) || !defined(reim_fft8_pom_offset) || !defined(fill_reim_fft8_omegas_f) +#error "missing reim8 definitions" +#endif +#if !defined(reim_fft4_f) || !defined(reim_fft4_pom_offset) || !defined(fill_reim_fft4_omegas_f) +#error "missing reim4 definitions" +#endif +#if !defined(reim_fft2_f) || !defined(reim_fft2_pom_offset) || !defined(fill_reim_fft2_omegas_f) +#error "missing reim2 definitions" +#endif +#if !defined(reim_twiddle_fft_f) || !defined(reim_twiddle_fft_pom_offset) || !defined(fill_reim_twiddle_fft_omegas_f) +#error "missing twiddle definitions" +#endif +#if !defined(reim_bitwiddle_fft_f) || !defined(reim_bitwiddle_fft_pom_offset) || \ + !defined(fill_reim_bitwiddle_fft_omegas_f) +#error "missing bitwiddle definitions" +#endif +#if !defined(reim_fft_bfs_16_f) || !defined(fill_reim_fft_bfs_16_omegas_f) +#error "missing bfs_16 definitions" +#endif +#if !defined(reim_fft_rec_16_f) || !defined(fill_reim_fft_rec_16_omegas_f) +#error "missing rec_16 definitions" +#endif +#if !defined(reim_fft_f) || !defined(fill_reim_fft_omegas_f) +#error "missing main definitions" +#endif + +void reim_fft_bfs_16_f(uint64_t m, double* re, double* im, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + if (log2m & 1) { + uint64_t h = mm >> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_f(h, re, im, *omg); + *omg += reim_twiddle_fft_pom_offset; + mm = h; + } + while (mm > 16) { + uint64_t h = mm >> 2; + for (uint64_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_f(h, re + off, im + off, *omg); + *omg += reim_bitwiddle_fft_pom_offset; + } + mm = h; + } + if (mm != 16) abort(); // bug! 
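  /* Sketch of how a specialization is expected to instantiate this template (the
   * `_mykernel` names are hypothetical; the pom offsets mirror the ones used by the
   * AVX2 code in reim_fft_avx2.c, i.e. 2 per twiddle, 4 per bitwiddle, 16 per fft16):
   *
   *   #define reim_fft16_f             reim_fft16_mykernel
   *   #define reim_fft16_pom_offset    16
   *   #define fill_reim_fft16_omegas_f fill_reim_fft16_omegas
   *   ...                              // likewise for fft8/fft4/fft2, twiddle,
   *   #define reim_fft_f               reim_fft_mykernel   // bitwiddle, bfs_16, rec_16
   *   #include "reim_fft_core_template.h"
   */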
+ for (uint64_t off = 0; off < m; off += 16) { + reim_fft16_f(re + off, im + off, *omg); + *omg += reim_fft16_pom_offset; + } +} + +void fill_reim_fft_bfs_16_omegas_f(uint64_t m, double entry_pwr, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + double ss = entry_pwr; + if (log2m % 2 != 0) { + uint64_t h = mm >> 1; + double s = ss / 2.; + // do the first twiddle iteration normally + fill_reim_twiddle_fft_omegas_f(s, omg); + mm = h; + ss = s; + } + while (mm > 16) { + uint64_t h = mm >> 2; + double s = ss / 4.; + for (uint64_t off = 0; off < m; off += mm) { + double rs0 = s + fracrevbits(off / mm) / 4.; + fill_reim_bitwiddle_fft_omegas_f(rs0, omg); + } + mm = h; + ss = s; + } + if (mm != 16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + double s = ss + fracrevbits(off / 16); + fill_reim_fft16_omegas_f(s, omg); + } +} + +void reim_fft_rec_16_f(uint64_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_f(m, re, im, omg); + const uint32_t h = m >> 1; + reim_twiddle_fft_f(h, re, im, *omg); + *omg += reim_twiddle_fft_pom_offset; + reim_fft_rec_16_f(h, re, im, omg); + reim_fft_rec_16_f(h, re + h, im + h, omg); +} + +void fill_reim_fft_rec_16_omegas_f(uint64_t m, double entry_pwr, double** omg) { + if (m <= 2048) return fill_reim_fft_bfs_16_omegas_f(m, entry_pwr, omg); + const uint64_t h = m / 2; + const double s = entry_pwr / 2; + fill_reim_twiddle_fft_omegas_f(s, omg); + fill_reim_fft_rec_16_omegas_f(h, s, omg); + fill_reim_fft_rec_16_omegas_f(h, s + 0.5, omg); +} + +void reim_fft_f(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_f(re, im, omg); + case 4: + return reim_fft4_f(re, im, omg); + case 8: + return reim_fft8_f(re, im, omg); + case 16: + return reim_fft16_f(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_f(m, re, im, &omg); + return reim_fft_rec_16_f(m, re, im, &omg); +} + +void fill_reim_fft_omegas_f(uint64_t m, double entry_pwr, double** omg) { + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_fft2_omegas_f(entry_pwr, omg); + break; + case 4: + fill_reim_fft4_omegas_f(entry_pwr, omg); + break; + case 8: + fill_reim_fft8_omegas_f(entry_pwr, omg); + break; + case 16: + fill_reim_fft16_omegas_f(entry_pwr, omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_fft_bfs_16_omegas_f(m, entry_pwr, omg); + } else { + fill_reim_fft_rec_16_omegas_f(m, entry_pwr, omg); + } +} + +#endif // SPQLIOS_REIM_FFT_CORE_TEMPLATE_H diff --git a/spqlios/lib/spqlios/reim/reim_fft_ifft.c b/spqlios/lib/spqlios/reim/reim_fft_ifft.c new file mode 100644 index 0000000..2a3c54e --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_ifft.c @@ -0,0 +1,37 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +double accurate_cos(int32_t i, int32_t n) { // cos(2pi*i/n) + i = ((i % n) + n) % n; + if (i >= 3 * n / 4) return cos(2. * M_PI * (n - i) / (double)(n)); + if (i >= 2 * n / 4) return -cos(2. * M_PI * (i - n / 2) / (double)(n)); + if (i >= 1 * n / 4) return -cos(2. * M_PI * (n / 2 - i) / (double)(n)); + return cos(2. 
* M_PI * (i) / (double)(n)); +} + +double accurate_sin(int32_t i, int32_t n) { // sin(2pi*i/n) + i = ((i % n) + n) % n; + if (i >= 3 * n / 4) return -sin(2. * M_PI * (n - i) / (double)(n)); + if (i >= 2 * n / 4) return -sin(2. * M_PI * (i - n / 2) / (double)(n)); + if (i >= 1 * n / 4) return sin(2. * M_PI * (n / 2 - i) / (double)(n)); + return sin(2. * M_PI * (i) / (double)(n)); +} + + +EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index) { + return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size); +} + +EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index) { + return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size); +} + + +EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } +EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } diff --git a/spqlios/lib/spqlios/reim/reim_fft_internal.h b/spqlios/lib/spqlios/reim/reim_fft_internal.h new file mode 100644 index 0000000..e30bbeb --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_internal.h @@ -0,0 +1,143 @@ +#ifndef SPQLIOS_REIM_FFT_INTERNAL_H +#define SPQLIOS_REIM_FFT_INTERNAL_H + +#include +#include +#include + +#include "reim_fft.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b); +EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b); + +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); + + +/** @brief r = x from ZnX (coeffs as signed int32_t's ) to double */ +EXPORT void reim_from_znx32_ref(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); +EXPORT void reim_from_znx32_avx2_fma(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); + +/** @brief r = x mod 1 (coeffs as double) */ +EXPORT void reim_to_tnx_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x); +EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x); + +EXPORT void reim_from_znx64_ref(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x); +EXPORT void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x); + +EXPORT void reim_to_znx64_ref(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); +EXPORT void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); +EXPORT void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); + +/** + * @brief compute the fft evaluations of P in place + * fft(data) = fft_rec(data, i); + * function fft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. 
re(s)>0 + * let (u,v) = merge_fft(data, s) + * return [fft_rec(u, s), fft_rec(v, -s)] + * } + * @param tables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void reim_fft_ref(const REIM_FFT_PRECOMP* tables, double* data); +EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data); + +/** + * @brief compute the ifft evaluations of P in place + * ifft(data) = ifft_rec(data, i); + * function ifft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. re(s)>0 + * let (u,v) = data + * return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s) + * } + * @param itables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void reim_ifft_ref(const REIM_IFFT_PRECOMP* itables, double* data); +EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* itables, double* data); + +/// new compressed implementation + +/** @brief naive FFT code mod X^m-exp(2i.pi.entry_pwr) */ +EXPORT void reim_naive_fft(uint64_t m, double entry_pwr, double* re, double* im); + +/** @brief 16-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft16_neon(double* dre, double* dim, const void* omega); +EXPORT void reim_fft16_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft16_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft16 functions */ +EXPORT void fill_reim_fft16_omegas(const double entry_pwr, double** omg); +EXPORT void fill_reim_fft16_omegas_neon(const double entry_pwr, double** omg); + +/** @brief 8-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft8_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft8_omegas(const double entry_pwr, double** omg); + +/** @brief 4-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft4_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft4_omegas(const double entry_pwr, double** omg); + +/** @brief 2-dimensional FFT with precomputed omegas */ +//EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft2_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft2_omegas(const double entry_pwr, double** omg); + +EXPORT void reim_fft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_fft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg); + +EXPORT void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + + +/** @brief naive FFT code mod X^m-exp(2i.pi.entry_pwr) */ +EXPORT void reim_naive_ifft(uint64_t m, double entry_pwr, double* re, double* im); + +/** @brief 16-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft16_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft16_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft16 functions */ +EXPORT void 
fill_reim_ifft16_omegas(const double entry_pwr, double** omg); + +/** @brief 8-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft8_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft8_omegas(const double entry_pwr, double** omg); + +/** @brief 4-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft4_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft4_omegas(const double entry_pwr, double** omg); + +/** @brief 2-dimensional FFT with precomputed omegas */ +//EXPORT void reim_ifft2_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft2_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft2_omegas(const double entry_pwr, double** omg); + +EXPORT void reim_ifft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_ifft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg); + +EXPORT void reim_ifft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_ifft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + +#endif // SPQLIOS_REIM_FFT_INTERNAL_H diff --git a/spqlios/lib/spqlios/reim/reim_fft_neon.c b/spqlios/lib/spqlios/reim/reim_fft_neon.c new file mode 100644 index 0000000..d350957 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_neon.c @@ -0,0 +1,1627 @@ +/* + * This file is adapted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/fft.c). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * The original source file has been modified by the authors of spqlios-arithmetic + * to be interfaced with the dynamic twiddle factors generator of the spqlios-arithmetic + * library, as well as the recursive dfs strategy for large complex dimensions m>=2048. + */ + +#include + +#include "../commons.h" +#include "../ext/neon_accel/macrof.h" +#include "../ext/neon_accel/macrofx4.h" + +void fill_reim_fft16_omegas_neon(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. / 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. * M_PI * (pin_4)); + (*omg)[5] = sin(2. * M_PI * (pin_4)); + (*omg)[6] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[7] = sin(2. 
* M_PI * (pin_4 + j_pow)); + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[8] = cos(2. * M_PI * (pin_8)); + (*omg)[10] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[12] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[14] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[9] = sin(2. * M_PI * (pin_8)); + (*omg)[11] = sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[13] = sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[15] = sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + *omg += 16; +} + + +EXPORT void reim_fft16_neon(double* dre, double* dim, const void* omega) { + const double* pom = omega; + // Total SIMD register: 28 = 24 + 4 + float64x2x2_t s_re_im; // 2 + float64x2x4_t x_re, x_im, y_re, y_im, t_re, t_im, v_re, v_im; // 32 + + { + /* + Level 2 + ( 8, 24) * ( 0, 1) + ( 9, 25) * ( 0, 1) + ( 10, 26) * ( 0, 1) + ( 11, 27) * ( 0, 1) + ( 12, 28) * ( 0, 1) + ( 13, 29) * ( 0, 1) + ( 14, 30) * ( 0, 1) + ( 15, 31) * ( 0, 1) + + ( 8, 24) = ( 0, 16) - @ + ( 9, 25) = ( 1, 17) - @ + ( 10, 26) = ( 2, 18) - @ + ( 11, 27) = ( 3, 19) - @ + ( 12, 28) = ( 4, 20) - @ + ( 13, 29) = ( 5, 21) - @ + ( 14, 30) = ( 6, 22) - @ + ( 15, 31) = ( 7, 23) - @ + + ( 0, 16) = ( 0, 16) + @ + ( 1, 17) = ( 1, 17) + @ + ( 2, 18) = ( 2, 18) + @ + ( 3, 19) = ( 3, 19) + @ + ( 4, 20) = ( 4, 20) + @ + ( 5, 21) = ( 5, 21) + @ + ( 6, 22) = ( 6, 22) + @ + ( 7, 23) = ( 7, 23) + @ + */ + vload(s_re_im.val[0], pom); + + vloadx4(y_re, dre + 8); + vloadx4(y_im, dim + 8); + + FWD_TOP_LANEx4(v_re, v_im, y_re, y_im, s_re_im.val[0]); + + vloadx4(x_re, dre); + vloadx4(x_im, dim); + + FWD_BOTx4(x_re, x_im, y_re, y_im, v_re, v_im); + + //vstorex4(dre, x_re); + //vstorex4(dim, x_im); + //vstorex4(dre + 8, y_re); + //vstorex4(dim + 8, y_im); + //return; + } + { + /* + Level 3 + + ( 4, 20) * ( 0, 1) + ( 5, 21) * ( 0, 1) + ( 6, 22) * ( 0, 1) + ( 7, 23) * ( 0, 1) + + ( 4, 20) = ( 0, 16) - @ + ( 5, 21) = ( 1, 17) - @ + ( 6, 22) = ( 2, 18) - @ + ( 7, 23) = ( 3, 19) - @ + + ( 0, 16) = ( 0, 16) + @ + ( 1, 17) = ( 1, 17) + @ + ( 2, 18) = ( 2, 18) + @ + ( 3, 19) = ( 3, 19) + @ + + ( 12, 28) * ( 0, 1) + ( 13, 29) * ( 0, 1) + ( 14, 30) * ( 0, 1) + ( 15, 31) * ( 0, 1) + + ( 12, 28) = ( 8, 24) - j@ + ( 13, 29) = ( 9, 25) - j@ + ( 14, 30) = ( 10, 26) - j@ + ( 15, 31) = ( 11, 27) - j@ + + ( 8, 24) = ( 8, 24) + j@ + ( 9, 25) = ( 9, 25) + j@ + ( 10, 26) = ( 10, 26) + j@ + ( 11, 27) = ( 11, 27) + j@ + */ + + //vloadx4(y_re, dre + 8); + //vloadx4(y_im, dim + 8); + //vloadx4(x_re, dre); + //vloadx4(x_im, dim); + + vload(s_re_im.val[0], pom + 2); + + FWD_TOP_LANE(t_re.val[0], t_im.val[0], x_re.val[2], x_im.val[2], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[2], t_im.val[2], y_re.val[2], y_im.val[2], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[0]); + + FWD_BOT(x_re.val[0], x_im.val[0], x_re.val[2], x_im.val[2], t_re.val[0], t_im.val[0]); + FWD_BOT(x_re.val[1], x_im.val[1], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + FWD_BOTJ(y_re.val[0], y_im.val[0], y_re.val[2], y_im.val[2], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[1], y_im.val[1], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + //vstorex4(dre, x_re); + //vstorex4(dim, x_im); + //vstorex4(dre + 8, y_re); + //vstorex4(dim + 8, y_im); + //return; + } + { + /* + Level 4 + + ( 2, 18) * ( 0, 1) + ( 3, 19) * ( 0, 1) + ( 6, 22) * ( 0, 1) + ( 7, 23) * ( 0, 1) + + ( 2, 18) = ( 0, 16) - @ + ( 3, 19) = ( 1, 17) - @ + ( 0, 16) = ( 0, 16) + @ + ( 1, 
17) = ( 1, 17) + @ + + ( 6, 22) = ( 4, 20) - j@ + ( 7, 23) = ( 5, 21) - j@ + ( 4, 20) = ( 4, 20) + j@ + ( 5, 21) = ( 5, 21) + j@ + + ( 10, 26) * ( 2, 3) + ( 11, 27) * ( 2, 3) + ( 14, 30) * ( 2, 3) + ( 15, 31) * ( 2, 3) + + ( 10, 26) = ( 8, 24) - @ + ( 11, 27) = ( 9, 25) - @ + ( 8, 24) = ( 8, 24) + @ + ( 9, 25) = ( 9, 25) + @ + + ( 14, 30) = ( 12, 28) - j@ + ( 15, 31) = ( 13, 29) - j@ + ( 12, 28) = ( 12, 28) + j@ + ( 13, 29) = ( 13, 29) + j@ + */ + //vloadx4(y_re, dre + 8); + //vloadx4(y_im, dim + 8); + //vloadx4(x_re, dre); + //vloadx4(x_im, dim); + + vloadx2(s_re_im, pom + 4); + + FWD_TOP_LANE(t_re.val[0], t_im.val[0], x_re.val[1], x_im.val[1], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[2], t_im.val[2], y_re.val[1], y_im.val[1], s_re_im.val[1]); + FWD_TOP_LANE(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[1]); + + FWD_BOT(x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0]); + FWD_BOTJ(x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + FWD_BOT(y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + //vstorex4(dre, x_re); + //vstorex4(dim, x_im); + //vstorex4(dre + 8, y_re); + //vstorex4(dim + 8, y_im); + //return; + } + { + /* + Level 5 + + ( 1, 17) * ( 0, 1) + ( 5, 21) * ( 2, 3) + ------ + ( 1, 17) = ( 0, 16) - @ + ( 5, 21) = ( 4, 20) - @ + ( 0, 16) = ( 0, 16) + @ + ( 4, 20) = ( 4, 20) + @ + + ( 3, 19) * ( 0, 1) + ( 7, 23) * ( 2, 3) + ------ + ( 3, 19) = ( 2, 18) - j@ + ( 7, 23) = ( 6, 22) - j@ + ( 2, 18) = ( 2, 18) + j@ + ( 6, 22) = ( 6, 22) + j@ + + ( 9, 25) * ( 4, 5) + ( 13, 29) * ( 6, 7) + ------ + ( 9, 25) = ( 8, 24) - @ + ( 13, 29) = ( 12, 28) - @ + ( 8, 24) = ( 8, 24) + @ + ( 12, 28) = ( 12, 28) + @ + + ( 11, 27) * ( 4, 5) + ( 15, 31) * ( 6, 7) + ------ + ( 11, 27) = ( 10, 26) - j@ + ( 15, 31) = ( 14, 30) - j@ + ( 10, 26) = ( 10, 26) + j@ + ( 14, 30) = ( 14, 30) + j@ + + before transpose_f64 + x_re: 0, 1 | 2, 3 | 4, 5 | 6, 7 + y_re: 8, 9 | 10, 11 | 12, 13 | 14, 15 + after transpose_f64 + x_re: 0, 4 | 2, 6 | 1, 5 | 3, 7 + y_re: 8, 12| 9, 13| 10, 14 | 11, 15 + after swap + x_re: 0, 4 | 1, 5 | 2, 6 | 3, 7 + y_re: 8, 12| 10, 14 | 9, 13| 11, 15 + */ + + //vloadx4(y_re, dre + 8); + //vloadx4(y_im, dim + 8); + //vloadx4(x_re, dre); + //vloadx4(x_im, dim); + + transpose_f64(x_re, x_re, v_re, 0, 2, 0); + transpose_f64(x_re, x_re, v_re, 1, 3, 1); + transpose_f64(x_im, x_im, v_im, 0, 2, 0); + transpose_f64(x_im, x_im, v_im, 1, 3, 1); + + + v_re.val[0] = x_re.val[2]; + x_re.val[2] = x_re.val[1]; + x_re.val[1] = v_re.val[0]; + + v_im.val[0] = x_im.val[2]; + x_im.val[2] = x_im.val[1]; + x_im.val[1] = v_im.val[0]; + + transpose_f64(y_re, y_re, v_re, 0, 2, 2); + transpose_f64(y_re, y_re, v_re, 1, 3, 3); + transpose_f64(y_im, y_im, v_im, 0, 2, 2); + transpose_f64(y_im, y_im, v_im, 1, 3, 3); + + v_re.val[0] = y_re.val[2]; + y_re.val[2] = y_re.val[1]; + y_re.val[1] = v_re.val[0]; + + v_im.val[0] = y_im.val[2]; + y_im.val[2] = y_im.val[1]; + y_im.val[1] = v_im.val[0]; + + //double pom8[] = {pom[8], pom[12], pom[9], pom[13]}; + vload2(s_re_im, pom+8); + //vload2(s_re_im, pom8); + + FWD_TOP(t_re.val[0], t_im.val[0], x_re.val[1], x_im.val[1], s_re_im.val[0], s_re_im.val[1]); + FWD_TOP(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + //double pom12[] = {pom[10], pom[14], pom[11], pom[15]}; + 
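+ // pom[12..15] hold the remaining two level-5 twiddles for the y half as
+ // interleaved (re, im) pairs (angles pin_8 + k_pow and pin_8 + j_pow + k_pow,
+ // cf. fill_reim_fft16_omegas_neon above); vload2 de-interleaves them into a
+ // {re, re} lane and an {im, im} lane.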
vload2(s_re_im, pom+12); + //vload2(s_re_im, pom12); + + FWD_TOP(t_re.val[2], t_im.val[2], y_re.val[1], y_im.val[1], s_re_im.val[0], s_re_im.val[1]); + FWD_TOP(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + FWD_BOT (x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0]); + FWD_BOTJ(x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + + vstore4(dre, x_re); + vstore4(dim, x_im); + + FWD_BOT (y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + vstore4(dre+8, y_re); + vstore4(dim+8, y_im); + } +} + +void reim_twiddle_fft_neon(uint64_t h, double* re, double* im, double om[2]) { + // Total SIMD register: 28 = 24 + 4 + if (h<8) abort(); // bug + float64x2_t s_re_im; // 2 + float64x2x4_t x_re, x_im, y_re, y_im, v_re, v_im; // 32 + vload(s_re_im, om); + for (uint64_t blk = 0; blk < h; blk+=8) { + double* dre = re + blk; + double* dim = im + blk; + vloadx4(y_re, dre + h); + vloadx4(y_im, dim + h); + FWD_TOP_LANEx4(v_re, v_im, y_re, y_im, s_re_im); + vloadx4(x_re, dre); + vloadx4(x_im, dim); + FWD_BOTx4(x_re, x_im, y_re, y_im, v_re, v_im); + vstorex4(dre, x_re); + vstorex4(dim, x_im); + vstorex4(dre + h, y_re); + vstorex4(dim + h, y_im); + } +} + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); +// i (omre + i omim) = -omim + i omre +void reim_citwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); + +void reim_bitwiddle_fft_neon(uint64_t h, double* re, double* im, double om[4]) { + // Total SIMD register: 28 = 24 + 4 + if (h<4) abort(); // bug + double* r0 = re; + double* r1 = re + h; + double* r2 = re + 2*h; + double* r3 = re + 3*h; + double* i0 = im; + double* i1 = im + h; + double* i2 = im + 2*h; + double* i3 = im + 3*h; + float64x2x2_t s_re_im; // 2 + float64x2x4_t v_re, v_im; // 2 + float64x2x2_t x0_re, x0_im, x1_re, x1_im; // 32 + float64x2x2_t x2_re, x2_im, x3_re, x3_im; // 32 + vloadx2(s_re_im, om); + for (uint64_t blk=0; blk> 1; + + int level = logn; + const fpr *fpr_tab5 = fpr_table[level--], + *fpr_tab4 = fpr_table[level--], + *fpr_tab3 = fpr_table[level--], + *fpr_tab2 = fpr_table[level]; + int k2 = 0, k3 = 0, k4 = 0, k5 = 0; + + for (unsigned j = 0; j < hn; j += 16) + { + /* + * ( 0, 16) - ( 1, 17) + * ( 4, 20) - ( 5, 21) + * ( 0, 16) + ( 1, 17) + * ( 4, 20) + ( 5, 21) + * ( 1, 17) = @ * ( 0, 1) + * ( 5, 21) = @ * ( 2, 3) + * + * ( 2, 18) - ( 3, 19) + * ( 6, 22) - ( 7, 23) + * ( 2, 18) + ( 3, 19) + * ( 6, 22) + ( 7, 23) + * ( 3, 19) = j@ * ( 0, 1) + * ( 7, 23) = j@ * ( 2, 3) + * + * ( 8, 24) - ( 9, 25) + * ( 12, 28) - ( 13, 29) + * ( 8, 24) + ( 9, 25) + * ( 12, 28) + ( 13, 29) + * ( 9, 25) = @ * ( 4, 5) + * ( 13, 29) = @ * ( 6, 7) + * + * ( 10, 26) - ( 11, 27) + * ( 14, 30) - ( 15, 31) + * ( 10, 26) + ( 11, 27) + * ( 14, 30) + ( 15, 31) + * ( 11, 27) = j@ * ( 4, 5) + * ( 15, 31) = j@ * ( 6, 7) + */ + + vload4(x_re, &f[j]); + vload4(x_im, &f[j + hn]); + + INV_TOPJ(t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1]); + INV_TOPJm(t_re.val[2], t_im.val[2], x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3]); + + vload4(y_re, &f[j + 8]); + vload4(y_im, &f[j + 8 + hn]) + + INV_TOPJ(t_re.val[1], t_im.val[1], y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3]); + + vload2(s_re_im, &fpr_tab5[k5]); 
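+ // fpr_tab5 is read as interleaved (re, im) twiddle pairs: vload2 de-interleaves
+ // four doubles so that s_re_im.val[0] carries the two real parts and
+ // s_re_im.val[1] the two imaginary parts consumed by INV_BOTJ/INV_BOTJm below.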
+ k5 += 4; + + INV_BOTJ (x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0], s_re_im.val[1]); + INV_BOTJm(x_re.val[3], x_im.val[3], t_re.val[2], t_im.val[2], s_re_im.val[0], s_re_im.val[1]); + + vload2(s_re_im, &fpr_tab5[k5]); + k5 += 4; + + INV_BOTJ (y_re.val[1], y_im.val[1], t_re.val[1], t_im.val[1], s_re_im.val[0], s_re_im.val[1]); + INV_BOTJm(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + + // x_re: 0, 4 | 1, 5 | 2, 6 | 3, 7 + // y_re: 8, 12| 9, 13|10, 14|11, 15 + + transpose_f64(x_re, x_re, t_re, 0, 1, 0); + transpose_f64(x_re, x_re, t_re, 2, 3, 1); + transpose_f64(y_re, y_re, t_re, 0, 1, 2); + transpose_f64(y_re, y_re, t_re, 2, 3, 3); + + transpose_f64(x_im, x_im, t_im, 0, 1, 0); + transpose_f64(x_im, x_im, t_im, 2, 3, 1); + transpose_f64(y_im, y_im, t_im, 0, 1, 2); + transpose_f64(y_im, y_im, t_im, 2, 3, 3); + + // x_re: 0, 1 | 4, 5 | 2, 3 | 6, 7 + // y_re: 8, 9 | 12,13|10,11 |14, 15 + + t_re.val[0] = x_re.val[1]; + x_re.val[1] = x_re.val[2]; + x_re.val[2] = t_re.val[0]; + + t_re.val[1] = y_re.val[1]; + y_re.val[1] = y_re.val[2]; + y_re.val[2] = t_re.val[1]; + + + t_im.val[0] = x_im.val[1]; + x_im.val[1] = x_im.val[2]; + x_im.val[2] = t_im.val[0]; + + t_im.val[1] = y_im.val[1]; + y_im.val[1] = y_im.val[2]; + y_im.val[2] = t_im.val[1]; + // x_re: 0, 1 | 2, 3| 4, 5 | 6, 7 + // y_re: 8, 9 | 10, 11| 12, 13| 14, 15 + + /* + * ( 0, 16) - ( 2, 18) + * ( 1, 17) - ( 3, 19) + * ( 0, 16) + ( 2, 18) + * ( 1, 17) + ( 3, 19) + * ( 2, 18) = @ * ( 0, 1) + * ( 3, 19) = @ * ( 0, 1) + * + * ( 4, 20) - ( 6, 22) + * ( 5, 21) - ( 7, 23) + * ( 4, 20) + ( 6, 22) + * ( 5, 21) + ( 7, 23) + * ( 6, 22) = j@ * ( 0, 1) + * ( 7, 23) = j@ * ( 0, 1) + * + * ( 8, 24) - ( 10, 26) + * ( 9, 25) - ( 11, 27) + * ( 8, 24) + ( 10, 26) + * ( 9, 25) + ( 11, 27) + * ( 10, 26) = @ * ( 2, 3) + * ( 11, 27) = @ * ( 2, 3) + * + * ( 12, 28) - ( 14, 30) + * ( 13, 29) - ( 15, 31) + * ( 12, 28) + ( 14, 30) + * ( 13, 29) + ( 15, 31) + * ( 14, 30) = j@ * ( 2, 3) + * ( 15, 31) = j@ * ( 2, 3) + */ + + INV_TOPJ (t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1]); + INV_TOPJm(t_re.val[1], t_im.val[1], x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3]); + + INV_TOPJ (t_re.val[2], t_im.val[2], y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3]); + + vloadx2(s_re_im, &fpr_tab4[k4]); + k4 += 4; + + INV_BOTJ_LANE (x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0]); + INV_BOTJm_LANE(x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0]); + + INV_BOTJ_LANE (y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2], s_re_im.val[1]); + INV_BOTJm_LANE(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[1]); + + /* + * ( 0, 16) - ( 4, 20) + * ( 1, 17) - ( 5, 21) + * ( 0, 16) + ( 4, 20) + * ( 1, 17) + ( 5, 21) + * ( 4, 20) = @ * ( 0, 1) + * ( 5, 21) = @ * ( 0, 1) + * + * ( 2, 18) - ( 6, 22) + * ( 3, 19) - ( 7, 23) + * ( 2, 18) + ( 6, 22) + * ( 3, 19) + ( 7, 23) + * ( 6, 22) = @ * ( 0, 1) + * ( 7, 23) = @ * ( 0, 1) + * + * ( 8, 24) - ( 12, 28) + * ( 9, 25) - ( 13, 29) + * ( 8, 24) + ( 12, 28) + * ( 9, 25) + ( 13, 29) + * ( 12, 28) = j@ * ( 0, 1) + * ( 13, 29) = j@ * ( 0, 1) + * + * ( 10, 26) - ( 14, 30) + * ( 11, 27) - ( 15, 31) + * ( 10, 26) + ( 14, 30) + * ( 11, 27) + ( 15, 31) + * ( 14, 30) = j@ * ( 0, 1) + * ( 15, 31) = j@ * ( 0, 1) + */ + + INV_TOPJ (t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[2], x_im.val[2]); + INV_TOPJ 
(t_re.val[1], t_im.val[1], x_re.val[1], x_im.val[1], x_re.val[3], x_im.val[3]); + + INV_TOPJm(t_re.val[2], t_im.val[2], y_re.val[0], y_im.val[0], y_re.val[2], y_im.val[2]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[1], y_im.val[1], y_re.val[3], y_im.val[3]); + + vload(s_re_im.val[0], &fpr_tab3[k3]); + k3 += 2; + + INV_BOTJ_LANE(x_re.val[2], x_im.val[2], t_re.val[0], t_im.val[0], s_re_im.val[0]); + INV_BOTJ_LANE(x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0]); + + INV_BOTJm_LANE(y_re.val[2], y_im.val[2], t_re.val[2], t_im.val[2], s_re_im.val[0]); + INV_BOTJm_LANE(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[0]); + + /* + * ( 0, 16) - ( 8, 24) + * ( 1, 17) - ( 9, 25) + * ( 0, 16) + ( 8, 24) + * ( 1, 17) + ( 9, 25) + * ( 8, 24) = @ * ( 0, 1) + * ( 9, 25) = @ * ( 0, 1) + * + * ( 2, 18) - ( 10, 26) + * ( 3, 19) - ( 11, 27) + * ( 2, 18) + ( 10, 26) + * ( 3, 19) + ( 11, 27) + * ( 10, 26) = @ * ( 0, 1) + * ( 11, 27) = @ * ( 0, 1) + * + * ( 4, 20) - ( 12, 28) + * ( 5, 21) - ( 13, 29) + * ( 4, 20) + ( 12, 28) + * ( 5, 21) + ( 13, 29) + * ( 12, 28) = @ * ( 0, 1) + * ( 13, 29) = @ * ( 0, 1) + * + * ( 6, 22) - ( 14, 30) + * ( 7, 23) - ( 15, 31) + * ( 6, 22) + ( 14, 30) + * ( 7, 23) + ( 15, 31) + * ( 14, 30) = @ * ( 0, 1) + * ( 15, 31) = @ * ( 0, 1) + */ + + + if ( (j >> 4) & 1) + { + INV_TOPJmx4(t_re, t_im, x_re, x_im, y_re, y_im); + } + else + { + INV_TOPJx4(t_re, t_im, x_re, x_im, y_re, y_im); + } + + vload(s_re_im.val[0], &fpr_tab2[k2]); + k2 += 2 * ((j & 31) == 16); + + if (last) + { + vfmuln(s_re_im.val[0], s_re_im.val[0], fpr_p2_tab[logn]); + vfmulnx4(x_re, x_re, fpr_p2_tab[logn]); + vfmulnx4(x_im, x_im, fpr_p2_tab[logn]); + } + vstorex4(&f[j], x_re); + vstorex4(&f[j + hn], x_im); + + if (logn == 5) + { + // Special case in fpr_tab_log2 where re == im + vfmulx4_i(t_re, t_re, s_re_im.val[0]); + vfmulx4_i(t_im, t_im, s_re_im.val[0]); + + vfaddx4(y_re, t_im, t_re); + vfsubx4(y_im, t_im, t_re); + } + else + { + if ((j >> 4) & 1) + { + INV_BOTJm_LANEx4(y_re, y_im, t_re, t_im, s_re_im.val[0]); + } + else + { + INV_BOTJ_LANEx4(y_re, y_im, t_re, t_im, s_re_im.val[0]); + } + } + + vstorex4(&f[j + 8], y_re); + vstorex4(&f[j + 8 + hn], y_im); + } +} + +static + void ZfN(iFFT_logn1)(fpr *f, const unsigned logn, const unsigned last) +{ + // Total SIMD register 26 = 24 + 2 + float64x2x4_t a_re, a_im, b_re, b_im, t_re, t_im; // 24 + float64x2_t s_re_im; // 2 + + const unsigned n = 1 << logn; + const unsigned hn = n >> 1; + const unsigned ht = n >> 2; + + for (unsigned j = 0; j < ht; j+= 8) + { + vloadx4(a_re, &f[j]); + vloadx4(a_im, &f[j + hn]); + vloadx4(b_re, &f[j + ht]); + vloadx4(b_im, &f[j + ht + hn]); + + INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im); + + s_re_im = vld1q_dup_f64(&fpr_tab_log2[0]); + + if (last) + { + vfmuln(s_re_im, s_re_im, fpr_p2_tab[logn]); + vfmulnx4(a_re, a_re, fpr_p2_tab[logn]); + vfmulnx4(a_im, a_im, fpr_p2_tab[logn]); + } + + vstorex4(&f[j], a_re); + vstorex4(&f[j + hn], a_im); + + vfmulx4_i(t_re, t_re, s_re_im); + vfmulx4_i(t_im, t_im, s_re_im); + + vfaddx4(b_re, t_im, t_re); + vfsubx4(b_im, t_im, t_re); + + vstorex4(&f[j + ht], b_re); + vstorex4(&f[j + ht + hn], b_im); + } +} + +// static +// void ZfN(iFFT_logn2)(fpr *f, const unsigned logn, const unsigned level, unsigned last) +// { +// const unsigned int falcon_n = 1 << logn; +// const unsigned int hn = falcon_n >> 1; + +// // Total SIMD register: 26 = 16 + 8 + 2 +// float64x2x4_t t_re, t_im; // 8 +// float64x2x2_t x1_re, x2_re, x1_im, x2_im, +// y1_re, y2_re, y1_im, y2_im; // 16 
+// float64x2_t s1_re_im, s2_re_im; // 2 + +// const fpr *fpr_inv_tab1 = NULL, *fpr_inv_tab2 = NULL; +// unsigned l, len, start, j, k1, k2; +// unsigned bar = logn - 4 - 2; +// unsigned Jm; + +// for (l = 4; l < logn - level - 1; l += 2) +// { +// len = 1 << l; +// last -= 1; +// fpr_inv_tab1 = fpr_table[bar--]; +// fpr_inv_tab2 = fpr_table[bar--]; +// k1 = 0; k2 = 0; + +// for (start = 0; start < hn; start += 1 << (l + 2)) +// { +// vload(s1_re_im, &fpr_inv_tab1[k1]); +// vload(s2_re_im, &fpr_inv_tab2[k2]); +// k1 += 2; +// k2 += 2 * ((start & 127) == 64); +// if (!last) +// { +// vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); +// } +// Jm = (start >> (l+ 2)) & 1; +// for (j = start; j < start + len; j += 4) +// { +// /* +// Level 6 +// * ( 0, 64) - ( 16, 80) +// * ( 1, 65) - ( 17, 81) +// * ( 0, 64) + ( 16, 80) +// * ( 1, 65) + ( 17, 81) +// * ( 16, 80) = @ * ( 0, 1) +// * ( 17, 81) = @ * ( 0, 1) +// * +// * ( 2, 66) - ( 18, 82) +// * ( 3, 67) - ( 19, 83) +// * ( 2, 66) + ( 18, 82) +// * ( 3, 67) + ( 19, 83) +// * ( 18, 82) = @ * ( 0, 1) +// * ( 19, 83) = @ * ( 0, 1) +// * +// * ( 32, 96) - ( 48, 112) +// * ( 33, 97) - ( 49, 113) +// * ( 32, 96) + ( 48, 112) +// * ( 33, 97) + ( 49, 113) +// * ( 48, 112) = j@ * ( 0, 1) +// * ( 49, 113) = j@ * ( 0, 1) +// * +// * ( 34, 98) - ( 50, 114) +// * ( 35, 99) - ( 51, 115) +// * ( 34, 98) + ( 50, 114) +// * ( 35, 99) + ( 51, 115) +// * ( 50, 114) = j@ * ( 0, 1) +// * ( 51, 115) = j@ * ( 0, 1) +// */ +// // x1: 0 -> 4 | 64 -> 67 +// // y1: 16 -> 19 | 80 -> 81 +// // x2: 32 -> 35 | 96 -> 99 +// // y2: 48 -> 51 | 112 -> 115 +// vloadx2(x1_re, &f[j]); +// vloadx2(x1_im, &f[j + hn]); +// vloadx2(y1_re, &f[j + len]); +// vloadx2(y1_im, &f[j + len + hn]); + +// INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); +// INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + +// vloadx2(x2_re, &f[j + 2*len]); +// vloadx2(x2_im, &f[j + 2*len + hn]); +// vloadx2(y2_re, &f[j + 3*len]); +// vloadx2(y2_im, &f[j + 3*len + hn]); + +// INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); +// INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + +// INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); +// INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); +// /* +// * Level 7 +// * ( 0, 64) - ( 32, 96) +// * ( 1, 65) - ( 33, 97) +// * ( 0, 64) + ( 32, 96) +// * ( 1, 65) + ( 33, 97) +// * ( 32, 96) = @ * ( 0, 1) +// * ( 33, 97) = @ * ( 0, 1) +// * +// * ( 2, 66) - ( 34, 98) +// * ( 3, 67) - ( 35, 99) +// * ( 2, 66) + ( 34, 98) +// * ( 3, 67) + ( 35, 99) +// * ( 34, 98) = @ * ( 0, 1) +// * ( 35, 99) = @ * ( 0, 1) +// * ---- +// * ( 16, 80) - ( 48, 112) +// * ( 17, 81) - ( 49, 113) +// * ( 16, 80) + ( 48, 112) +// * ( 17, 81) + ( 49, 113) +// * ( 48, 112) = @ * ( 0, 1) +// * ( 49, 113) = @ * ( 0, 1) +// * +// * ( 18, 82) - ( 50, 114) +// * ( 19, 83) - ( 51, 115) +// * ( 18, 82) + ( 50, 114) +// * ( 19, 83) + ( 51, 115) +// * ( 50, 114) = @ * ( 0, 1) +// * ( 51, 115) = @ * ( 0, 1) +// */ + +// if (Jm) +// { +// INV_TOPJm(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); +// INV_TOPJm(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], 
x2_im.val[1]); + +// INV_TOPJm(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJm(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJm_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); +// INV_BOTJm_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); +// INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); +// INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); +// } +// else +// { +// INV_TOPJ(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); +// INV_TOPJ(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + +// INV_TOPJ(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJ(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJ_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); +// INV_BOTJ_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); +// INV_BOTJ_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); +// INV_BOTJ_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); +// } + +// vstorex2(&f[j + 2*len], x2_re); +// vstorex2(&f[j + 2*len + hn], x2_im); + +// vstorex2(&f[j + 3*len], y2_re); +// vstorex2(&f[j + 3*len + hn], y2_im); + +// if (!last) +// { +// vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); +// vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); +// vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); +// vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + +// vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); +// vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); +// vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); +// vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); +// } + +// vstorex2(&f[j], x1_re); +// vstorex2(&f[j + hn], x1_im); + +// vstorex2(&f[j + len], y1_re); +// vstorex2(&f[j + len + hn], y1_im); + +// } +// } +// } +// } + + + +static + void ZfN(iFFT_logn2)(fpr *f, const unsigned logn, const unsigned level, unsigned last) +{ + const unsigned int falcon_n = 1 << logn; + const unsigned int hn = falcon_n >> 1; + + // Total SIMD register: 26 = 16 + 8 + 2 + float64x2x4_t t_re, t_im; // 8 + float64x2x2_t x1_re, x2_re, x1_im, x2_im, + y1_re, y2_re, y1_im, y2_im; // 16 + float64x2_t s1_re_im, s2_re_im; // 2 + + const fpr *fpr_inv_tab1 = NULL, *fpr_inv_tab2 = NULL; + unsigned l, len, start, j, k1, k2; + unsigned bar = logn - 4; + + for (l = 4; l < logn - level - 1; l += 2) + { + len = 1 << l; + last -= 1; + fpr_inv_tab1 = fpr_table[bar--]; + fpr_inv_tab2 = fpr_table[bar--]; + k1 = 0; k2 = 0; + + for (start = 0; start < hn; start += 1 << (l + 2)) + { + vload(s1_re_im, &fpr_inv_tab1[k1]); + vload(s2_re_im, &fpr_inv_tab2[k2]); + k1 += 2; + k2 += 2 * ((start & 127) == 64); + if (!last) + { + vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); + } + for (j = start; j < start + len; j += 4) + { + /* + Level 6 + * ( 0, 64) - ( 16, 80) + * ( 1, 65) - ( 17, 81) + * ( 0, 64) + ( 16, 80) + * ( 1, 65) + ( 17, 81) + * ( 16, 80) = @ * ( 0, 1) + * ( 17, 81) = @ * ( 0, 1) + * + * ( 2, 66) - ( 18, 82) + * ( 3, 67) - ( 19, 83) + * ( 2, 66) + ( 18, 82) + * ( 3, 67) + ( 19, 83) + * ( 18, 82) = @ * ( 0, 1) + * ( 19, 83) = @ * ( 0, 1) + * + * ( 32, 96) - ( 48, 112) + * ( 33, 97) - ( 49, 113) + * ( 32, 96) + ( 48, 112) + * ( 33, 97) + ( 49, 113) + * ( 48, 112) = j@ * ( 0, 
1) + * ( 49, 113) = j@ * ( 0, 1) + * + * ( 34, 98) - ( 50, 114) + * ( 35, 99) - ( 51, 115) + * ( 34, 98) + ( 50, 114) + * ( 35, 99) + ( 51, 115) + * ( 50, 114) = j@ * ( 0, 1) + * ( 51, 115) = j@ * ( 0, 1) + */ + // x1: 0 -> 4 | 64 -> 67 + // y1: 16 -> 19 | 80 -> 81 + // x2: 32 -> 35 | 96 -> 99 + // y2: 48 -> 51 | 112 -> 115 + vloadx2(x1_re, &f[j]); + vloadx2(x1_im, &f[j + hn]); + vloadx2(y1_re, &f[j + len]); + vloadx2(y1_im, &f[j + len + hn]); + + INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); + INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + + vloadx2(x2_re, &f[j + 2*len]); + vloadx2(x2_im, &f[j + 2*len + hn]); + vloadx2(y2_re, &f[j + 3*len]); + vloadx2(y2_im, &f[j + 3*len + hn]); + + INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); + INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); + /* + * Level 7 + * ( 0, 64) - ( 32, 96) + * ( 1, 65) - ( 33, 97) + * ( 0, 64) + ( 32, 96) + * ( 1, 65) + ( 33, 97) + * ( 32, 96) = @ * ( 0, 1) + * ( 33, 97) = @ * ( 0, 1) + * + * ( 2, 66) - ( 34, 98) + * ( 3, 67) - ( 35, 99) + * ( 2, 66) + ( 34, 98) + * ( 3, 67) + ( 35, 99) + * ( 34, 98) = @ * ( 0, 1) + * ( 35, 99) = @ * ( 0, 1) + * ---- + * ( 16, 80) - ( 48, 112) + * ( 17, 81) - ( 49, 113) + * ( 16, 80) + ( 48, 112) + * ( 17, 81) + ( 49, 113) + * ( 48, 112) = @ * ( 0, 1) + * ( 49, 113) = @ * ( 0, 1) + * + * ( 18, 82) - ( 50, 114) + * ( 19, 83) - ( 51, 115) + * ( 18, 82) + ( 50, 114) + * ( 19, 83) + ( 51, 115) + * ( 50, 114) = @ * ( 0, 1) + * ( 51, 115) = @ * ( 0, 1) + */ + + + INV_TOPJ(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); + INV_TOPJ(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + + INV_TOPJ(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJ(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); + INV_BOTJ_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); + INV_BOTJ_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); + INV_BOTJ_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); + + vstorex2(&f[j + 2*len], x2_re); + vstorex2(&f[j + 2*len + hn], x2_im); + + vstorex2(&f[j + 3*len], y2_re); + vstorex2(&f[j + 3*len + hn], y2_im); + + if (!last) + { + vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); + vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); + vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); + vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + + vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); + vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); + vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); + vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); + } + + vstorex2(&f[j], x1_re); + vstorex2(&f[j + hn], x1_im); + + vstorex2(&f[j + len], y1_re); + vstorex2(&f[j + len + hn], y1_im); + + } + // + start += 1 << (l + 2); + if (start >= hn) break; + + vload(s1_re_im, &fpr_inv_tab1[k1]); 
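+ // second start-block of the same level pair: the body below is identical to the
+ // one above except that the (l+1)-level butterflies use the INV_TOPJm/INV_BOTJm_LANE
+ // variants, which is why the loop body is duplicated here instead of branching on
+ // (start >> (l + 2)) & 1 as in the commented-out version further up.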
+ vload(s2_re_im, &fpr_inv_tab2[k2]); + k1 += 2; + k2 += 2 * ((start & 127) == 64); + if (!last) + { + vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); + } + + for (j = start; j < start + len; j += 4) + { + /* + Level 6 + * ( 0, 64) - ( 16, 80) + * ( 1, 65) - ( 17, 81) + * ( 0, 64) + ( 16, 80) + * ( 1, 65) + ( 17, 81) + * ( 16, 80) = @ * ( 0, 1) + * ( 17, 81) = @ * ( 0, 1) + * + * ( 2, 66) - ( 18, 82) + * ( 3, 67) - ( 19, 83) + * ( 2, 66) + ( 18, 82) + * ( 3, 67) + ( 19, 83) + * ( 18, 82) = @ * ( 0, 1) + * ( 19, 83) = @ * ( 0, 1) + * + * ( 32, 96) - ( 48, 112) + * ( 33, 97) - ( 49, 113) + * ( 32, 96) + ( 48, 112) + * ( 33, 97) + ( 49, 113) + * ( 48, 112) = j@ * ( 0, 1) + * ( 49, 113) = j@ * ( 0, 1) + * + * ( 34, 98) - ( 50, 114) + * ( 35, 99) - ( 51, 115) + * ( 34, 98) + ( 50, 114) + * ( 35, 99) + ( 51, 115) + * ( 50, 114) = j@ * ( 0, 1) + * ( 51, 115) = j@ * ( 0, 1) + */ + // x1: 0 -> 4 | 64 -> 67 + // y1: 16 -> 19 | 80 -> 81 + // x2: 32 -> 35 | 96 -> 99 + // y2: 48 -> 51 | 112 -> 115 + vloadx2(x1_re, &f[j]); + vloadx2(x1_im, &f[j + hn]); + vloadx2(y1_re, &f[j + len]); + vloadx2(y1_im, &f[j + len + hn]); + + INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); + INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + + vloadx2(x2_re, &f[j + 2*len]); + vloadx2(x2_im, &f[j + 2*len + hn]); + vloadx2(y2_re, &f[j + 3*len]); + vloadx2(y2_im, &f[j + 3*len + hn]); + + INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); + INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); + /* + * Level 7 + * ( 0, 64) - ( 32, 96) + * ( 1, 65) - ( 33, 97) + * ( 0, 64) + ( 32, 96) + * ( 1, 65) + ( 33, 97) + * ( 32, 96) = @ * ( 0, 1) + * ( 33, 97) = @ * ( 0, 1) + * + * ( 2, 66) - ( 34, 98) + * ( 3, 67) - ( 35, 99) + * ( 2, 66) + ( 34, 98) + * ( 3, 67) + ( 35, 99) + * ( 34, 98) = @ * ( 0, 1) + * ( 35, 99) = @ * ( 0, 1) + * ---- + * ( 16, 80) - ( 48, 112) + * ( 17, 81) - ( 49, 113) + * ( 16, 80) + ( 48, 112) + * ( 17, 81) + ( 49, 113) + * ( 48, 112) = @ * ( 0, 1) + * ( 49, 113) = @ * ( 0, 1) + * + * ( 18, 82) - ( 50, 114) + * ( 19, 83) - ( 51, 115) + * ( 18, 82) + ( 50, 114) + * ( 19, 83) + ( 51, 115) + * ( 50, 114) = @ * ( 0, 1) + * ( 51, 115) = @ * ( 0, 1) + */ + + INV_TOPJm(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); + INV_TOPJm(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + + INV_TOPJm(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJm_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); + INV_BOTJm_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); + + vstorex2(&f[j + 2*len], x2_re); + vstorex2(&f[j + 2*len + hn], x2_im); + + vstorex2(&f[j + 3*len], y2_re); + vstorex2(&f[j + 3*len + hn], y2_im); + + if (!last) + { + 
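+ // once last has counted down to zero, the inverse-FFT normalisation constant
+ // fpr_p2_tab[logn] is folded directly into x1/y1 here; x2/y2 already absorbed
+ // it through the pre-scaled s2_re_im above, so no separate scaling pass over f
+ // is needed.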
vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); + vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); + vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); + vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + + vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); + vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); + vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); + vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); + } + + vstorex2(&f[j], x1_re); + vstorex2(&f[j + hn], x1_im); + + vstorex2(&f[j + len], y1_re); + vstorex2(&f[j + len + hn], y1_im); + + } + // + } + } +} + + +/* +* Support logn from [1, 10] +* Can be easily extended to logn > 10 +*/ +void ZfN(iFFT)(fpr *f, const unsigned logn) +{ + const unsigned level = (logn - 5) & 1; + + switch (logn) + { + case 2: + ZfN(iFFT_log2)(f); + break; + + case 3: + ZfN(iFFT_log3)(f); + break; + + case 4: + ZfN(iFFT_log4)(f); + break; + + case 5: + ZfN(iFFT_log5)(f, 5, 1); + break; + + case 6: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn1)(f, logn, 1); + break; + + case 7: + case 9: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn2)(f, logn, level, 1); + break; + + case 8: + case 10: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn2)(f, logn, level, 0); + ZfN(iFFT_logn1)(f, logn, 1); + break; + + default: + break; + } +} +#endif + +// generic fft stuff + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); +void reim_citwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); + +void reim_fft16_ref(double* dre, double* dim, const void* pom); +void fill_reim_fft16_omegas(const double entry_pwr, double** omg); +void reim_fft8_ref(double* dre, double* dim, const void* pom); +void reim_twiddle_fft_ref(uint64_t h, double* re, double* im, double om[2]); +void fill_reim_twiddle_fft_ref(const double s, double** omg); +void reim_bitwiddle_fft_ref(uint64_t h, double* re, double* im, double om[4]); +void fill_reim_bitwiddle_fft_ref(const double s, double** omg); +void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + +void fill_reim_twiddle_fft_omegas_ref(const double rs0, double** omg) { + (*omg)[0] = cos(2 * M_PI * rs0); + (*omg)[1] = sin(2 * M_PI * rs0); + *omg += 2; +} + +void fill_reim_bitwiddle_fft_omegas_ref(const double rs0, double** omg) { + double rs1 = 2. * rs0; + (*omg)[0] = cos(2 * M_PI * rs1); + (*omg)[1] = sin(2 * M_PI * rs1); + (*omg)[2] = cos(2 * M_PI * rs0); + (*omg)[3] = sin(2 * M_PI * rs0); + *omg += 4; +} + +#define reim_fft16_f reim_fft16_neon +#define reim_fft16_pom_offset 16 +#define fill_reim_fft16_omegas_f fill_reim_fft16_omegas_neon +// currently, m=4 uses the ref implem (corner case with low impact) TODO!! +#define reim_fft8_f reim_fft8_ref +#define reim_fft8_pom_offset 8 +#define fill_reim_fft8_omegas_f fill_reim_fft8_omegas +// currently, m=4 uses the ref implem (corner case with low impact) TODO!! 
+#define reim_fft4_f reim_fft4_ref +#define reim_fft4_pom_offset 4 +#define fill_reim_fft4_omegas_f fill_reim_fft4_omegas +// m = 2 will use the ref implem, since intrinsics don't provide any speed-up +#define reim_fft2_f reim_fft2_ref +#define reim_fft2_pom_offset 2 +#define fill_reim_fft2_omegas_f fill_reim_fft2_omegas + +// neon twiddle use the same omegas layout as the ref implem +#define reim_twiddle_fft_f reim_twiddle_fft_neon +#define reim_twiddle_fft_pom_offset 2 +#define fill_reim_twiddle_fft_omegas_f fill_reim_twiddle_fft_omegas_ref + +// neon bi-twiddle use the same omegas layout as the ref implem +#define reim_bitwiddle_fft_f reim_bitwiddle_fft_neon +#define reim_bitwiddle_fft_pom_offset 4 +#define fill_reim_bitwiddle_fft_omegas_f fill_reim_bitwiddle_fft_omegas_ref + +// template functions to produce +#define reim_fft_bfs_16_f reim_fft_bfs_16_neon +#define fill_reim_fft_bfs_16_omegas_f fill_reim_fft_bfs_16_omegas_neon +#define reim_fft_rec_16_f reim_fft_rec_16_neon +#define fill_reim_fft_rec_16_omegas_f fill_reim_fft_rec_16_omegas_neon +#define reim_fft_f reim_fft_neon +#define fill_reim_fft_omegas_f fill_reim_fft_omegas_neon + +#include "reim_fft_core_template.h" + +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_FFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? + + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_FFT_PRECOMP)); + REIM_FFT_PRECOMP* r = (REIM_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*)r->powomegas; + fill_reim_fft_omegas_f(m, 0.25, &omg); + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + //{ + // if (CPU_SUPPORTS("fma")) { + // r->function = reim_fft_avx2_fma; + // } else { + // r->function = reim_fft_ref; + // } + //} + r->function = reim_fft_f; + return reps; +} diff --git a/spqlios/lib/spqlios/reim/reim_fft_private.h b/spqlios/lib/spqlios/reim/reim_fft_private.h new file mode 100644 index 0000000..3fb456f --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_private.h @@ -0,0 +1,101 @@ +#ifndef SPQLIOS_REIM_FFT_PRIVATE_H +#define SPQLIOS_REIM_FFT_PRIVATE_H + +#include "../commons_private.h" +#include "reim_fft.h" + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +typedef struct reim_twiddle_precomp REIM_FFTVEC_TWIDDLE_PRECOMP; +typedef struct reim_bitwiddle_precomp REIM_FFTVEC_BITWIDDLE_PRECOMP; + +typedef void (*FFT_FUNC)(const REIM_FFT_PRECOMP*, double*); +typedef void (*IFFT_FUNC)(const REIM_IFFT_PRECOMP*, double*); +typedef void (*FFTVEC_MUL_FUNC)(const REIM_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); +typedef void (*FFTVEC_ADDMUL_FUNC)(const REIM_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); + +typedef void (*FROM_ZNX32_FUNC)(const REIM_FROM_ZNX32_PRECOMP*, void*, const int32_t*); +typedef void (*FROM_ZNX64_FUNC)(const REIM_FROM_ZNX64_PRECOMP*, void*, const int64_t*); +typedef void (*FROM_TNX32_FUNC)(const REIM_FROM_TNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_TNX32_FUNC)(const REIM_TO_TNX32_PRECOMP*, int32_t*, const void*); +typedef void (*TO_TNX_FUNC)(const REIM_TO_TNX_PRECOMP*, double*, 
const double*); +typedef void (*TO_ZNX64_FUNC)(const REIM_TO_ZNX64_PRECOMP*, int64_t*, const void*); +typedef void (*FFTVEC_TWIDDLE_FUNC)(const REIM_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_BITWIDDLE_FUNC)(const REIM_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); + +typedef struct reim_fft_precomp { + FFT_FUNC function; + int64_t m; ///< complex dimension warning: reim uses n=2N=4m + uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) + double* powomegas; ///< 64B aligned + void* aligned_buffers; ///< 64B aligned +} REIM_FFT_PRECOMP; + +typedef struct reim_ifft_precomp { + IFFT_FUNC function; + int64_t m; // warning: reim uses n=2N=4m + uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) + double* powomegas; + void* aligned_buffers; +} REIM_IFFT_PRECOMP; + +typedef struct reim_mul_precomp { + FFTVEC_MUL_FUNC function; + int64_t m; +} REIM_FFTVEC_MUL_PRECOMP; + +typedef struct reim_addmul_precomp { + FFTVEC_ADDMUL_FUNC function; + int64_t m; +} REIM_FFTVEC_ADDMUL_PRECOMP; + + +struct reim_from_znx32_precomp { + FROM_ZNX32_FUNC function; + int64_t m; +}; + +struct reim_from_znx64_precomp { + FROM_ZNX64_FUNC function; + int64_t m; +}; + +struct reim_from_tnx32_precomp { + FROM_TNX32_FUNC function; + int64_t m; +}; + +struct reim_to_tnx32_precomp { + TO_TNX32_FUNC function; + int64_t m; + double divisor; +}; + +struct reim_to_tnx_precomp { + TO_TNX_FUNC function; + int64_t m; + double divisor; + uint32_t log2overhead; + double add_cst; + uint64_t mask_and; + uint64_t mask_or; + double sub_cst; +}; + +struct reim_to_znx64_precomp { + TO_ZNX64_FUNC function; + int64_t m; + double divisor; +}; + +struct reim_twiddle_precomp { + FFTVEC_TWIDDLE_FUNC function; + int64_t m; +}; + +struct reim_bitwiddle_precomp { + FFTVEC_BITWIDDLE_FUNC function; + int64_t m; +}; + +#endif // SPQLIOS_REIM_FFT_PRIVATE_H diff --git a/spqlios/lib/spqlios/reim/reim_fft_ref.c b/spqlios/lib/spqlios/reim/reim_fft_ref.c new file mode 100644 index 0000000..616e954 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fft_ref.c @@ -0,0 +1,437 @@ +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_fft_simple(uint32_t m, void* data) { + static REIM_FFT_PRECOMP* p[31] = {0}; + REIM_FFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void reim_ifft_simple(uint32_t m, void* data) { + static REIM_IFFT_PRECOMP* p[31] = {0}; + REIM_IFFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_ifft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void reim_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_MUL_PRECOMP* p[31] = {0}; + REIM_FFTVEC_MUL_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_mul_precomp(m); + (*f)->function(*f, r, a, b); +} + +EXPORT void reim_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_ADDMUL_PRECOMP* p[31] = {0}; + REIM_FFTVEC_ADDMUL_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_addmul_precomp(m); + (*f)->function(*f, r, a, b); +} + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double newrt = *rb * omre - *ib * omim; + double newit = *rb * omim + *ib * omre; + *rb = *ra - newrt; + *ib = *ia - newit; + *ra = *ra + newrt; + *ia = *ia + newit; +} + +// i (omre + i omim) = -omim + i omre +void reim_citwiddle(double* ra, double* ia, double* rb, 
double* ib, double omre, double omim) { + double newrt = -*rb * omim - *ib * omre; + double newit = *rb * omre - *ib * omim; + *rb = *ra - newrt; + *ib = *ia - newit; + *ra = *ra + newrt; + *ia = *ia + newit; +} + +void reim_fft16_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omre = om[0]; + double omim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[8], &dim[8], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[9], &dim[9], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[10], &dim[10], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[11], &dim[11], omre, omim); + reim_ctwiddle(&dre[4], &dim[4], &dre[12], &dim[12], omre, omim); + reim_ctwiddle(&dre[5], &dim[5], &dre[13], &dim[13], omre, omim); + reim_ctwiddle(&dre[6], &dim[6], &dre[14], &dim[14], omre, omim); + reim_ctwiddle(&dre[7], &dim[7], &dre[15], &dim[15], omre, omim); + } + { + double omre = om[2]; + double omim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + reim_citwiddle(&dre[8], &dim[8], &dre[12], &dim[12], omre, omim); + reim_citwiddle(&dre[9], &dim[9], &dre[13], &dim[13], omre, omim); + reim_citwiddle(&dre[10], &dim[10], &dre[14], &dim[14], omre, omim); + reim_citwiddle(&dre[11], &dim[11], &dre[15], &dim[15], omre, omim); + } + { + double omare = om[4]; + double omaim = om[5]; + double ombre = om[6]; + double ombim = om[7]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_citwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_citwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + reim_ctwiddle(&dre[8], &dim[8], &dre[10], &dim[10], ombre, ombim); + reim_ctwiddle(&dre[9], &dim[9], &dre[11], &dim[11], ombre, ombim); + reim_citwiddle(&dre[12], &dim[12], &dre[14], &dim[14], ombre, ombim); + reim_citwiddle(&dre[13], &dim[13], &dre[15], &dim[15], ombre, ombim); + } + { + double omare = om[8]; + double ombre = om[9]; + double omcre = om[10]; + double omdre = om[11]; + double omaim = om[12]; + double ombim = om[13]; + double omcim = om[14]; + double omdim = om[15]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_ctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_citwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + reim_ctwiddle(&dre[8], &dim[8], &dre[9], &dim[9], omcre, omcim); + reim_citwiddle(&dre[10], &dim[10], &dre[11], &dim[11], omcre, omcim); + reim_ctwiddle(&dre[12], &dim[12], &dre[13], &dim[13], omdre, omdim); + reim_citwiddle(&dre[14], &dim[14], &dre[15], &dim[15], omdre, omdim); + } +} + +void fill_reim_fft16_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. / 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. 
* M_PI * (pin_4)); + (*omg)[5] = sin(2. * M_PI * (pin_4)); + (*omg)[6] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[7] = sin(2. * M_PI * (pin_4 + j_pow)); + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[8] = cos(2. * M_PI * (pin_8)); + (*omg)[9] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[10] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[11] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[12] = sin(2. * M_PI * (pin_8)); + (*omg)[13] = sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[14] = sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[15] = sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + *omg += 16; +} + +void reim_fft8_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omre = om[0]; + double omim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_citwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_citwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + } + { + double omare = om[4]; + double ombre = om[5]; + double omaim = om[6]; + double ombim = om[7]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_ctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_citwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + } +} + +void fill_reim_fft8_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. * M_PI * (pin_4)); + (*omg)[5] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[6] = sin(2. * M_PI * (pin_4)); + (*omg)[7] = sin(2. * M_PI * (pin_4 + j_pow)); + *omg += 8; +} + +void reim_fft4_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + } +} + +void fill_reim_fft4_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. 
* M_PI * (pin_2)); + *omg += 4; +} + +void reim_fft2_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + } +} + +void fill_reim_fft2_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + *omg += 2; +} + +void reim_twiddle_fft_ref(uint64_t h, double* re, double* im, double om[2]) { + for (uint64_t i=0; i> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_ref(h, re, im, *omg); + *omg += 2; + mm = h; + } + while (mm > 16) { + uint64_t h = mm >> 2; + for (uint64_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_ref(h, re + off, im + off, *omg); + *omg += 4; + } + mm = h; + } + if (mm!=16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + reim_fft16_ref(re+off, im+off, *omg); + *omg += 16; + } +} + +void fill_reim_fft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + double ss = entry_pwr; + if (log2m % 2 != 0) { + uint64_t h = mm >> 1; + double s = ss / 2.; + // do the first twiddle iteration normally + (*omg)[0] = cos(2 * M_PI * s); + (*omg)[1] = sin(2 * M_PI * s); + *omg += 2; + mm = h; + ss = s; + } + while (mm > 16) { + uint64_t h = mm >> 2; + double s = ss / 4.; + for (uint64_t off = 0; off < m; off += mm) { + double rs0 = s + fracrevbits(off/mm) / 4.; + double rs1 = 2. * rs0; + (*omg)[0] = cos(2 * M_PI * rs1); + (*omg)[1] = sin(2 * M_PI * rs1); + (*omg)[2] = cos(2 * M_PI * rs0); + (*omg)[3] = sin(2 * M_PI * rs0); + *omg += 4; + } + mm = h; + ss = s; + } + if (mm!=16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + double s = ss + fracrevbits(off/16); + fill_reim_fft16_omegas(s, omg); + } +} + +void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_ref(m, re, im, omg); + const uint32_t h = m / 2; + reim_twiddle_fft_ref(h, re, im, *omg); + *omg += 2; + reim_fft_rec_16_ref(h, re, im, omg); + reim_fft_rec_16_ref(h, re + h, im + h, omg); +} + +void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg) { + if (m <= 2048) return fill_reim_fft_bfs_16_omegas(m, entry_pwr, omg); + const uint64_t h = m / 2; + const double s = entry_pwr / 2; + (*omg)[0] = cos(2 * M_PI * s); + (*omg)[1] = sin(2 * M_PI * s); + *omg += 2; + fill_reim_fft_rec_16_omegas(h, s, omg); + fill_reim_fft_rec_16_omegas(h, s + 0.5, omg); +} + +void reim_fft_ref(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat+m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_ref(re, im, omg); + case 4: + return reim_fft4_ref(re, im, omg); + case 8: + return reim_fft8_ref(re, im, omg); + case 16: + return reim_fft16_ref(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_ref(m, re, im, &omg); + return reim_fft_rec_16_ref(m, re, im, &omg); +} + +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_FFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? 
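+ // resulting layout: struct header, padding up to the next 64-byte boundary,
+ // the omega table (OMG_SPACE bytes), then num_buffers scratch buffers of
+ // BUF_SIZE bytes each (see the aligned_addr computation below).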
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_FFT_PRECOMP)); + REIM_FFT_PRECOMP* r = (REIM_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*) r->powomegas; + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_fft2_omegas(0.25, &omg); + break; + case 4: + fill_reim_fft4_omegas(0.25, &omg); + break; + case 8: + fill_reim_fft8_omegas(0.25, &omg); + break; + case 16: + fill_reim_fft16_omegas(0.25, &omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_fft_bfs_16_omegas(m, 0.25, &omg); + } else { + fill_reim_fft_rec_16_omegas(m, 0.25, &omg); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (CPU_SUPPORTS("fma")) { + r->function = reim_fft_avx2_fma; + } else { + r->function = reim_fft_ref; + } + } + return reps; +} + + +void reim_naive_fft(uint64_t m, double entry_pwr, double* re, double* im) { + if (m == 1) return; + // twiddle + const uint64_t h = m / 2; + const double s = entry_pwr / 2.; + const double sre = cos(2 * M_PI * s); + const double sim = sin(2 * M_PI * s); + for (uint64_t j = 0; j < h; ++j) { + double pre = re[h + j] * sre - im[h + j] * sim; + double pim = im[h + j] * sre + re[h + j] * sim; + re[h + j] = re[j] - pre; + im[h + j] = im[j] - pim; + re[j] += pre; + im[j] += pim; + } + reim_naive_fft(h, s, re, im); + reim_naive_fft(h, s + 0.5, re + h, im + h); +} diff --git a/spqlios/lib/spqlios/reim/reim_fftvec_addmul_fma.c b/spqlios/lib/spqlios/reim/reim_fftvec_addmul_fma.c new file mode 100644 index 0000000..ec7a350 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fftvec_addmul_fma.c @@ -0,0 +1,75 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + __m256d rr = _mm256_loadu_pd(rr_ptr); + __m256d ri = _mm256_loadu_pd(ri_ptr); + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + const __m256d bi = _mm256_loadu_pd(bi_ptr); + + rr = _mm256_fmsub_pd(ai, bi, rr); + rr = _mm256_fmsub_pd(ar, br, rr); + ri = _mm256_fmadd_pd(ar, bi, ri); + ri = _mm256_fmadd_pd(ai, br, ri); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} + +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + const __m256d bi = _mm256_loadu_pd(bi_ptr); + + const __m256d t1 = _mm256_mul_pd(ai, bi); + 
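+    // Complex product (ar + i*ai)*(br + i*bi): the real part ar*br - ai*bi and the
+    // imaginary part ar*bi + ai*br are each assembled from one plain multiply (t1, t2)
+    // and one fused multiply-add/sub below.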
const __m256d t2 = _mm256_mul_pd(ar, bi); + + __m256d rr = _mm256_fmsub_pd(ar, br, t1); + __m256d ri = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} diff --git a/spqlios/lib/spqlios/reim/reim_fftvec_addmul_ref.c b/spqlios/lib/spqlios/reim/reim_fftvec_addmul_ref.c new file mode 100644 index 0000000..df411f6 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_fftvec_addmul_ref.c @@ -0,0 +1,54 @@ +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + const uint64_t m = precomp->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] * b[i] - a[i + m] * b[i + m]; + double im = a[i] * b[i + m] + a[i + m] * b[i]; + r[i] += re; + r[i + m] += im; + } +} + +EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] * b[i] - a[i + m] * b[i + m]; + double im = a[i] * b[i + m] + a[i + m] * b[i]; + r[i] = re; + r[i + m] = im; + } +} + +EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m) { + REIM_FFTVEC_ADDMUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_ADDMUL_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_addmul_fma; + } else { + reps->function = reim_fftvec_addmul_ref; + } + } else { + reps->function = reim_fftvec_addmul_ref; + } + return reps; +} + +EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m) { + REIM_FFTVEC_MUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_MUL_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_mul_fma; + } else { + reps->function = reim_fftvec_mul_ref; + } + } else { + reps->function = reim_fftvec_mul_ref; + } + return reps; +} + diff --git a/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma.s b/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma.s new file mode 100644 index 0000000..4657c5d --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma.s @@ -0,0 +1,192 @@ +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +.globl reim_ifft16_avx_fma +reim_ifft16_avx_fma: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +1: +vmovupd 0x00(%rdx),%ymm12 +vmovupd 0x20(%rdx),%ymm13 + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + 
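+# after the permutes and unpacks above, ymm0/ymm2 hold re/im of coefficients {0,4,8,12},
+# ymm4/ymm6 of {1,5,9,13}, ymm1/ymm3 of {2,6,10,14} and ymm5/ymm7 of {3,7,11,15},
+# so the first inverse stage below butterflies the adjacent pairs (2k, 2k+1).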
+# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm4,%ymm0,%ymm8 # retw +vsubpd %ymm5,%ymm1,%ymm9 # reitw +vsubpd %ymm6,%ymm2,%ymm10 # imtw +vsubpd %ymm7,%ymm3,%ymm11 # imitw +vaddpd %ymm4,%ymm0,%ymm0 +vaddpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vaddpd %ymm7,%ymm3,%ymm3 +# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +/* +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 +*/ +2: +vmovupd 0x40(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +/* +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) +*/ + +# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13) +vsubpd %ymm8,%ymm0,%ymm4 # retw +vsubpd %ymm9,%ymm1,%ymm5 # reitw +vsubpd %ymm10,%ymm2,%ymm6 # imtw +vsubpd %ymm11,%ymm3,%ymm7 # imitw +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vaddpd %ymm11,%ymm3,%ymm3 +# multiply 4,5,6,7 by 12,13, result to 8,9,10,11 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x60(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, 
%ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm1,%ymm0,%ymm8 # retw +vsubpd %ymm3,%ymm2,%ymm9 # reitw +vsubpd %ymm5,%ymm4,%ymm10 # imtw +vsubpd %ymm7,%ymm6,%ymm11 # imitw +vaddpd %ymm1,%ymm0,%ymm0 +vaddpd %ymm3,%ymm2,%ymm2 +vaddpd %ymm5,%ymm4,%ymm4 +vaddpd %ymm7,%ymm6,%ymm6 +# multiply 8,9,10,11 by 12,13, result to 1,3,5,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +4: +vmovupd 0x70(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13) +# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm2,%ymm0,%ymm8 # retw1 +vsubpd %ymm3,%ymm1,%ymm9 # retw2 +vsubpd %ymm6,%ymm4,%ymm10 # imtw1 +vsubpd %ymm7,%ymm5,%ymm11 # imtw2 +vaddpd %ymm2,%ymm0,%ymm0 +vaddpd %ymm3,%ymm1,%ymm1 +vaddpd %ymm6,%ymm4,%ymm4 +vaddpd %ymm7,%ymm5,%ymm5 +# multiply 8,9,10,11 by 12,13, result to 2,3,6,7 +# twiddles use reom=ymm12, imom=ymm13 +vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai +vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai +vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai +vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai +vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0 +vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4 +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 +vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4 + +5: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret +.size reim_ifft16_avx_fma, .-reim_ifft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma_win32.s b/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma_win32.s new file mode 100644 index 0000000..cf554c5 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft16_avx_fma_win32.s @@ -0,0 +1,228 @@ + .text + .p2align 4 + .globl reim_ifft16_avx_fma + .def reim_ifft16_avx_fma; .scl 2; .type 32; .endef +reim_ifft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + movq %r8,%rdx + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq reim_ifft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +#.globl reim_ifft16_avx_fma_amd64 +reim_ifft16_avx_fma_amd64: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # 
ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +1: +vmovupd 0x00(%rdx),%ymm12 +vmovupd 0x20(%rdx),%ymm13 + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm4,%ymm0,%ymm8 # retw +vsubpd %ymm5,%ymm1,%ymm9 # reitw +vsubpd %ymm6,%ymm2,%ymm10 # imtw +vsubpd %ymm7,%ymm3,%ymm11 # imitw +vaddpd %ymm4,%ymm0,%ymm0 +vaddpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vaddpd %ymm7,%ymm3,%ymm3 +# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +/* +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 +*/ +2: +vmovupd 0x40(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +/* +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) +*/ + +# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13) +vsubpd %ymm8,%ymm0,%ymm4 # retw +vsubpd %ymm9,%ymm1,%ymm5 # reitw +vsubpd %ymm10,%ymm2,%ymm6 # imtw +vsubpd %ymm11,%ymm3,%ymm7 
# imitw +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vaddpd %ymm11,%ymm3,%ymm3 +# multiply 4,5,6,7 by 12,13, result to 8,9,10,11 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x60(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm1,%ymm0,%ymm8 # retw +vsubpd %ymm3,%ymm2,%ymm9 # reitw +vsubpd %ymm5,%ymm4,%ymm10 # imtw +vsubpd %ymm7,%ymm6,%ymm11 # imitw +vaddpd %ymm1,%ymm0,%ymm0 +vaddpd %ymm3,%ymm2,%ymm2 +vaddpd %ymm5,%ymm4,%ymm4 +vaddpd %ymm7,%ymm6,%ymm6 +# multiply 8,9,10,11 by 12,13, result to 1,3,5,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +4: +vmovupd 0x70(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13) +# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm2,%ymm0,%ymm8 # retw1 +vsubpd %ymm3,%ymm1,%ymm9 # retw2 +vsubpd %ymm6,%ymm4,%ymm10 # imtw1 +vsubpd %ymm7,%ymm5,%ymm11 # imtw2 +vaddpd %ymm2,%ymm0,%ymm0 +vaddpd %ymm3,%ymm1,%ymm1 +vaddpd %ymm6,%ymm4,%ymm4 +vaddpd %ymm7,%ymm5,%ymm5 +# multiply 8,9,10,11 by 12,13, result to 2,3,6,7 +# twiddles use reom=ymm12, imom=ymm13 +vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai +vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai +vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai +vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai +vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0 +vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4 +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 +vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4 + +5: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret diff --git a/spqlios/lib/spqlios/reim/reim_ifft4_avx_fma.c b/spqlios/lib/spqlios/reim/reim_ifft4_avx_fma.c new file mode 100644 index 0000000..5266c95 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft4_avx_fma.c @@ -0,0 +1,62 @@ +#include +#include +#include + +#include "reim_fft_private.h" + 
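The AVX kernels below are vectorized forms of the scalar butterflies in reim_ifft_ref.c: the inverse twiddle maps (a, b) to (a + b, (a - b)*w), and, applied with the conjugate root after the forward twiddle (a, b) -> (a + w*b, a - w*b), it returns the inputs scaled by 2, which is why an m-point ifft of the fft yields m times the identity. A minimal scalar sketch of that identity (illustrative only; the names cplx_t, fwd_ctwiddle and inv_ctwiddle are not part of the library):

typedef struct { double re, im; } cplx_t;

static cplx_t cmul(cplx_t x, cplx_t y) {
  return (cplx_t){x.re * y.re - x.im * y.im, x.re * y.im + x.im * y.re};
}

// forward butterfly: (a, b) -> (a + w*b, a - w*b)
static void fwd_ctwiddle(cplx_t* a, cplx_t* b, cplx_t w) {
  cplx_t p = cmul(w, *b);
  cplx_t t = *a;
  *a = (cplx_t){t.re + p.re, t.im + p.im};
  *b = (cplx_t){t.re - p.re, t.im - p.im};
}

// inverse butterfly: (a, b) -> (a + b, (a - b)*w)
static void inv_ctwiddle(cplx_t* a, cplx_t* b, cplx_t w) {
  cplx_t d = {a->re - b->re, a->im - b->im};
  *a = (cplx_t){a->re + b->re, a->im + b->im};
  *b = cmul(d, w);
}

// inv_ctwiddle with conj(w) undoes fwd_ctwiddle with w, up to a factor 2
static void roundtrip(cplx_t* a, cplx_t* b, cplx_t w) {
  fwd_ctwiddle(a, b, w);
  inv_ctwiddle(a, b, (cplx_t){w.re, -w.im});  // now holds (2a, 2b)
}
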
+__always_inline void reim_invctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, + const __m128d omim) { + __m128d rdiff = _mm_sub_pd(*ra, *rb); + __m128d idiff = _mm_sub_pd(*ia, *ib); + *ra = _mm_add_pd(*ra, *rb); + *ia = _mm_add_pd(*ia, *ib); + + *rb = _mm_mul_pd(idiff, omim); + *rb = _mm_fmsub_pd(rdiff, omre, *rb); + + *ib = _mm_mul_pd(rdiff, omim); + *ib = _mm_fmadd_pd(idiff, omre, *ib); +} + +EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m128d ra0 = _mm_loadu_pd(dre); + __m128d ra2 = _mm_loadu_pd(dre + 2); + __m128d ia0 = _mm_loadu_pd(dim); + __m128d ia2 = _mm_loadu_pd(dim + 2); + + // 1 + { + const __m128d ifft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d omre = _mm_loadu_pd(omp); // omre: r,i + __m128d omim = _mm_xor_pd(_mm_permute_pd(omre, 1), ifft4neg); // omim: i,-r + + __m128d ra = _mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) + __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) + __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) + __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); + + ra0 = _mm_unpacklo_pd(ra, rb); + ia0 = _mm_unpacklo_pd(ia, ib); + ra2 = _mm_unpackhi_pd(ra, rb); + ia2 = _mm_unpackhi_pd(ia, ib); + } + + // 2 + { + __m128d om = _mm_loadu_pd(omp + 2); + __m128d omre = _mm_permute_pd(om, 0); + __m128d omim = _mm_permute_pd(om, 3); + + reim_invctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); + } + + // 4 + _mm_storeu_pd(dre, ra0); + _mm_storeu_pd(dre + 2, ra2); + _mm_storeu_pd(dim, ia0); + _mm_storeu_pd(dim + 2, ia2); +} diff --git a/spqlios/lib/spqlios/reim/reim_ifft8_avx_fma.c b/spqlios/lib/spqlios/reim/reim_ifft8_avx_fma.c new file mode 100644 index 0000000..b85fdda --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft8_avx_fma.c @@ -0,0 +1,86 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_invctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, + const __m256d omim) { + __m256d rdiff = _mm256_sub_pd(*ra, *rb); + __m256d idiff = _mm256_sub_pd(*ia, *ib); + *ra = _mm256_add_pd(*ra, *rb); + *ia = _mm256_add_pd(*ia, *ib); + + *rb = _mm256_mul_pd(idiff, omim); + *rb = _mm256_fmsub_pd(rdiff, omre, *rb); + + *ib = _mm256_mul_pd(rdiff, omim); + *ib = _mm256_fmadd_pd(idiff, omre, *ib); +} + +EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m256d ra0 = _mm256_loadu_pd(dre); + __m256d ra4 = _mm256_loadu_pd(dre + 4); + __m256d ia0 = _mm256_loadu_pd(dim); + __m256d ia4 = _mm256_loadu_pd(dim + 4); + + // 1 + { + const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); + __m256d omr = _mm256_loadu_pd(omp); // r0,r1,i0,i1 + __m256d omiirr = _mm256_permute2f128_pd(omr, omr, 1); // i0,i1,r0,r1 + __m256d omi = _mm256_xor_pd(omiirr, fft8neg2); // i0,i1,-r0,-r1 + + __m256d rb = _mm256_unpackhi_pd(ra0, ra4); + __m256d ib = _mm256_unpackhi_pd(ia0, ia4); + __m256d ra = _mm256_unpacklo_pd(ra0, ra4); + __m256d ia = _mm256_unpacklo_pd(ia0, ia4); + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra4 = _mm256_unpackhi_pd(ra, rb); + ia4 = _mm256_unpackhi_pd(ia, ib); + ra0 = _mm256_unpacklo_pd(ra, rb); + ia0 = _mm256_unpacklo_pd(ia, ib); + } + + // 2 + { + const __m128d ifft8neg 
= _mm_castsi128_pd(_mm_set_epi64x(0, UINT64_C(1) << 63)); + __m128d omri = _mm_loadu_pd(omp + 4); // r,i + __m128d ommri = _mm_xor_pd(omri, ifft8neg); // -r,i + __m256d omrimri = _mm256_set_m128d(ommri, omri); // r,i,-r,i + __m256d omi = _mm256_permute_pd(omrimri, 3); // i,i,-r,-r + __m256d omr = _mm256_permute_pd(omrimri, 12); // r,r,i,i + + __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); + __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); + __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); + __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); + ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); + ia0 = _mm256_permute2f128_pd(ia, ib, 0x20); + ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); + } + + // 3 + { + __m128d omri = _mm_loadu_pd(omp + 6); // r,i + __m256d omriri = _mm256_set_m128d(omri, omri); // r,i,r,i + __m256d omi = _mm256_permute_pd(omriri, 15); // i,i,i,i + __m256d omr = _mm256_permute_pd(omriri, 0); // r,r,r,r + + reim_invctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi); + } + + // 4 + _mm256_storeu_pd(dre, ra0); + _mm256_storeu_pd(dre + 4, ra4); + _mm256_storeu_pd(dim, ia0); + _mm256_storeu_pd(dim + 4, ia4); +} diff --git a/spqlios/lib/spqlios/reim/reim_ifft_avx2.c b/spqlios/lib/spqlios/reim/reim_ifft_avx2.c new file mode 100644 index 0000000..3c42ea5 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft_avx2.c @@ -0,0 +1,167 @@ +#include "immintrin.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_invtwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[2]); +void reim_invbitwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[4]); + +__always_inline void reim_invtwiddle_ifft_avx2_fma(uint32_t h, double* re, double* im, double om[2]) { + const __m128d omx = _mm_load_pd(om); + const __m256d omra = _mm256_set_m128d(omx, omx); + const __m256d omi = _mm256_unpackhi_pd(omra, omra); + const __m256d omr = _mm256_unpacklo_pd(omra, omra); + double* r0 = re; + double* r1 = re + h; + double* i0 = im; + double* i1 = im + h; + for (uint32_t i=0; i>1; // m/2 + while (h < ms2) { + uint32_t mm = h << 2; + for (uint32_t off = 0; off < m; off += mm) { + reim_invbitwiddle_ifft_avx2_fma(h, re + off, im + off, *omg); + *omg += 4; + } + h = mm; + } + if (log2m & 1) { + //if (h!=ms2) abort(); // bug + // do the first twiddle iteration normally + reim_invtwiddle_ifft_avx2_fma(h, re, im, *omg); + *omg += 2; + //h = m; + } +} + +void reim_ifft_rec_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_ifft_bfs_16_avx2_fma(m, re, im, omg); + const uint32_t h = m >> 1; //m / 2; + reim_ifft_rec_16_avx2_fma(h, re, im, omg); + reim_ifft_rec_16_avx2_fma(h, re + h, im + h, omg); + reim_invtwiddle_ifft_avx2_fma(h, re, im, *omg); + *omg += 2; +} + +void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat+m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_ifft2_ref(re, im, omg); + case 4: + return reim_ifft4_avx_fma(re, im, omg); + case 8: + return reim_ifft8_avx_fma(re, im, omg); + case 16: + return reim_ifft16_avx_fma(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_ifft_bfs_16_avx2_fma(m, re, im, &omg); + return reim_ifft_rec_16_avx2_fma(m, re, im, &omg); +} diff --git a/spqlios/lib/spqlios/reim/reim_ifft_ref.c 
b/spqlios/lib/spqlios/reim/reim_ifft_ref.c new file mode 100644 index 0000000..d667e36 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_ifft_ref.c @@ -0,0 +1,409 @@ +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_invctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double rdiff = *ra - *rb; + double idiff = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = rdiff * omre - idiff * omim; + *ib = rdiff * omim + idiff * omre; +} + +// i (omre + i omim) = -omim + i omre +void reim_invcitwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double rdiff = *ra - *rb; + double idiff = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = rdiff * omim + idiff * omre; + *ib = - rdiff * omre + idiff * omim; +} + +void reim_ifft16_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double ombre = om[1]; + double omcre = om[2]; + double omdre = om[3]; + double omaim = om[4]; + double ombim = om[5]; + double omcim = om[6]; + double omdim = om[7]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_invctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_invcitwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + reim_invctwiddle(&dre[8], &dim[8], &dre[9], &dim[9], omcre, omcim); + reim_invcitwiddle(&dre[10], &dim[10], &dre[11], &dim[11], omcre, omcim); + reim_invctwiddle(&dre[12], &dim[12], &dre[13], &dim[13], omdre, omdim); + reim_invcitwiddle(&dre[14], &dim[14], &dre[15], &dim[15], omdre, omdim); + } + { + double omare = om[8]; + double omaim = om[9]; + double ombre = om[10]; + double ombim = om[11]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_invcitwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_invcitwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + reim_invctwiddle(&dre[8], &dim[8], &dre[10], &dim[10], ombre, ombim); + reim_invctwiddle(&dre[9], &dim[9], &dre[11], &dim[11], ombre, ombim); + reim_invcitwiddle(&dre[12], &dim[12], &dre[14], &dim[14], ombre, ombim); + reim_invcitwiddle(&dre[13], &dim[13], &dre[15], &dim[15], ombre, ombim); + } + { + double omre = om[12]; + double omim = om[13]; + reim_invctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + reim_invcitwiddle(&dre[8], &dim[8], &dre[12], &dim[12], omre, omim); + reim_invcitwiddle(&dre[9], &dim[9], &dre[13], &dim[13], omre, omim); + reim_invcitwiddle(&dre[10], &dim[10], &dre[14], &dim[14], omre, omim); + reim_invcitwiddle(&dre[11], &dim[11], &dre[15], &dim[15], omre, omim); + } + { + double omre = om[14]; + double omim = om[15]; + reim_invctwiddle(&dre[0], &dim[0], &dre[8], &dim[8], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[9], &dim[9], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[10], &dim[10], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[11], &dim[11], omre, omim); + reim_invctwiddle(&dre[4], &dim[4], &dre[12], &dim[12], omre, omim); + reim_invctwiddle(&dre[5], &dim[5], &dre[13], &dim[13], omre, omim); + reim_invctwiddle(&dre[6], &dim[6], &dre[14], 
&dim[14], omre, omim); + reim_invctwiddle(&dre[7], &dim[7], &dre[15], &dim[15], omre, omim); + } +} + +void fill_reim_ifft16_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. / 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[0] = cos(2. * M_PI * (pin_8)); + (*omg)[1] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[2] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[3] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[4] = - sin(2. * M_PI * (pin_8)); + (*omg)[5] = - sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[6] = - sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[7] = - sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[8] = cos(2. * M_PI * (pin_4)); + (*omg)[9] = - sin(2. * M_PI * (pin_4)); + (*omg)[10] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[11] = - sin(2. * M_PI * (pin_4 + j_pow)); + // 2 and 3 are real and imag of om^1/2 + (*omg)[12] = cos(2. * M_PI * (pin_2)); + (*omg)[13] = - sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[14] = cos(2. * M_PI * pin); + (*omg)[15] = - sin(2. * M_PI * pin); + *omg += 16; +} + +void reim_ifft8_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double ombre = om[1]; + double omaim = om[2]; + double ombim = om[3]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_invctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_invcitwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + } + { + double omare = om[4]; + double omaim = om[5]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_invcitwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_invcitwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + } + { + double omre = om[6]; + double omim = om[7]; + reim_invctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + } +} + +void fill_reim_ifft8_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[0] = cos(2. * M_PI * (pin_4)); + (*omg)[1] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[2] = - sin(2. * M_PI * (pin_4)); + (*omg)[3] = - sin(2. * M_PI * (pin_4 + j_pow)); + // 2 and 3 are real and imag of om^1/2 + (*omg)[4] = cos(2. * M_PI * (pin_2)); + (*omg)[5] = - sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[6] = cos(2. * M_PI * pin); + (*omg)[7] = - sin(2. 
* M_PI * pin); + *omg += 8; +} + +void reim_ifft4_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + } +} + +void fill_reim_ifft4_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + // 2 and 3 are real and imag of om^1/2 + (*omg)[0] = cos(2. * M_PI * (pin_2)); + (*omg)[1] = - sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[2] = cos(2. * M_PI * pin); + (*omg)[3] = - sin(2. * M_PI * pin); + *omg += 4; +} + +void reim_ifft2_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + } +} + +void fill_reim_ifft2_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = - sin(2. * M_PI * pin); + *omg += 2; +} + +void reim_invtwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[2]) { + for (uint64_t i=0; im; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat+m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_ifft2_ref(re, im, omg); + case 4: + return reim_ifft4_ref(re, im, omg); + case 8: + return reim_ifft8_ref(re, im, omg); + case 16: + return reim_ifft16_ref(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_ifft_bfs_16_ref(m, re, im, &omg); + return reim_ifft_rec_16_ref(m, re, im, &omg); +} + +EXPORT REIM_IFFT_PRECOMP* new_reim_ifft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_IFFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? 
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_IFFT_PRECOMP)); + REIM_IFFT_PRECOMP* r = (REIM_IFFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*) r->powomegas; + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_ifft2_omegas(0.25, &omg); + break; + case 4: + fill_reim_ifft4_omegas(0.25, &omg); + break; + case 8: + fill_reim_ifft8_omegas(0.25, &omg); + break; + case 16: + fill_reim_ifft16_omegas(0.25, &omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_ifft_bfs_16_omegas(m, 0.25, &omg); + } else { + fill_reim_ifft_rec_16_omegas(m, 0.25, &omg); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (CPU_SUPPORTS("fma")) { + r->function = reim_ifft_avx2_fma; + } else { + r->function = reim_ifft_ref; + } + } + return reps; +} + + +void reim_naive_ifft(uint64_t m, double entry_pwr, double* re, double* im) { + if (m == 1) return; + // twiddle + const uint64_t h = m / 2; + const double s = entry_pwr / 2.; + reim_naive_ifft(h, s, re, im); + reim_naive_ifft(h, s + 0.5, re + h, im + h); + const double sre = cos(2 * M_PI * s); + const double sim = - sin(2 * M_PI * s); + for (uint64_t j = 0; j < h; ++j) { + double rdiff = re[j] - re[h+j]; + double idiff = im[j] - im[h+j]; + re[j] = re[j] + re[h+j]; + im[j] = im[j] + im[h+j]; + re[h+j] = rdiff * sre - idiff * sim; + im[h + j] = idiff * sre + rdiff * sim; + } +} diff --git a/spqlios/lib/spqlios/reim/reim_to_tnx_avx.c b/spqlios/lib/spqlios/reim/reim_to_tnx_avx.c new file mode 100644 index 0000000..daf7ef5 --- /dev/null +++ b/spqlios/lib/spqlios/reim/reim_to_tnx_avx.c @@ -0,0 +1,32 @@ +#include +#include +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +typedef union {double d; uint64_t u;} dblui64_t; + +EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + const uint64_t n = tables->m << 1; + const __m256d add_cst = _mm256_set1_pd(tables->add_cst); + const __m256d mask_and = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_and)); + const __m256d mask_or = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_or)); + const __m256d sub_cst = _mm256_set1_pd(tables->sub_cst); + __m256d reg0,reg1; + for (uint64_t i=0; i +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_to_tnx_basic_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + const uint64_t n = tables->m << 1; + double divisor = tables->divisor; + for (uint64_t i=0; im << 1; + double add_cst = tables->add_cst; + uint64_t mask_and = tables->mask_and; + uint64_t mask_or = tables->mask_or; + double sub_cst = tables->sub_cst; + dblui64_t cur; + for (uint64_t i=0; i 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + res->log2overhead = log2overhead; + // 52 + 11 + 1 + // ......1.......01(1)|expo|sign + // .......=========(1)|expo|sign msbbits = log2ovh + 2 + 11 + 1 + uint64_t nbits = 50 - log2overhead; + dblui64_t ovh_cst; + ovh_cst.d = 0.5 + (6<add_cst = ovh_cst.d * divisor; + res->mask_and = ((UINT64_C(1) << nbits) - 1); + res->mask_or = ovh_cst.u & ((UINT64_C(-1)) << nbits); + res->sub_cst = ovh_cst.d; + // TODO: check 
selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_to_tnx_avx; + } else { + res->function = reim_to_tnx_ref; + } + } else { + res->function = reim_to_tnx_ref; + } + return res; +} + +EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + REIM_TO_TNX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_tnx_precomp(res, m, divisor, log2overhead)); +} + +EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + tables->function(tables, r, x); +} diff --git a/spqlios/lib/spqlios/reim4/reim4_arithmetic.h b/spqlios/lib/spqlios/reim4/reim4_arithmetic.h new file mode 100644 index 0000000..75ae1d7 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_arithmetic.h @@ -0,0 +1,149 @@ +#ifndef SPQLIOS_REIM4_ARITHMETIC_H +#define SPQLIOS_REIM4_ARITHMETIC_H + +#include "../commons.h" + +// the reim4 structure represent 4 complex numbers, +// represented as re0,re1,re2,re3,im0,im1,im2,im3 + +// TODO implement these basic arithmetic functions. (only ref is needed) + +/** @brief dest = 0 */ +EXPORT void reim4_zero(double* dest); +/** @brief dest = a + b */ +EXPORT void reim4_add(double* dest, const double* u, const double* v); +/** @brief dest = a * b */ +EXPORT void reim4_mul(double* dest, const double* u, const double* v); +/** @brief dest += a * b */ +EXPORT void reim4_add_mul(double* dest, const double* a, const double* b); + +// TODO add the dot products: vec x mat1col, vec x mat2cols functions here (ref and avx) + +// TODO implement the convolution functions here (ref and avx) +/** + * @brief k-th coefficient of the convolution product a * b. + * + * The k-th coefficient is defined as sum_{i+j=k} a[i].b[j]. 
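+ * For example, with sizea = sizeb = 3, coefficient k = 2 is a[0].b[2] + a[1].b[1] + a[2].b[0],
+ * where each product is a reim4 product of 4 complex numbers (8 doubles).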
+ * (in this sum, i and j must remain within the bounds [0,sizea[ and [0,sizeb[) + * + * In practice, accounting for these bounds, the convolution function boils down to + * ``` + * res := 0 + * if (k +#include + +#include "reim4_arithmetic.h" + +void reim4_extract_1blk_from_reim_avx(uint64_t m, uint64_t blk, + double* const dst, // nrows * 8 doubles + const double* const src // a contiguous array of nrows reim vectors +) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + 4, _mm256_loadu_pd(src_ptr + m)); +} + +void reim4_extract_1blk_from_contiguous_reim_avx(uint64_t m, uint64_t nrows, uint64_t blk, double* const dst, + const double* const src) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows * 2; ++i) { + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + dst_ptr += 4; + src_ptr += m; + } +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_sl_avx(uint64_t m, uint64_t sl, uint64_t nrows, uint64_t blk, + double* const dst, const double* const src) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows; ++i) { + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + 4, _mm256_loadu_pd(src_ptr + m)); + dst_ptr += 8; + src_ptr += sl; + } +} + +void reim4_save_1blk_to_reim_avx(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + const double* src_ptr = src; + double* dst_ptr = dst + (blk << 2); + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + m, _mm256_loadu_pd(src_ptr + 4)); +} + +__always_inline void cplx_prod(__m256d* re1, __m256d* re2, __m256d* im, const double* const u_ptr, + const double* const v_ptr) { + const __m256d a = _mm256_loadu_pd(u_ptr); + const __m256d c = _mm256_loadu_pd(v_ptr); + *re1 = _mm256_fmadd_pd(a, c, *re1); + + const __m256d b = _mm256_loadu_pd(u_ptr + 4); + *im = _mm256_fmadd_pd(b, c, *im); + + const __m256d d = _mm256_loadu_pd(v_ptr + 4); + *re2 = _mm256_fmadd_pd(b, d, *re2); + *im = _mm256_fmadd_pd(a, d, *im); +} + +void reim4_vec_mat1col_product_avx2(const uint64_t nrows, + double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + __m256d re1 = _mm256_setzero_pd(); + __m256d re2 = _mm256_setzero_pd(); + __m256d im1 = _mm256_setzero_pd(); + __m256d im2 = _mm256_setzero_pd(); + + const double* u_ptr = u; + const double* v_ptr = v; + for (uint64_t i = 0; i < nrows; ++i) { + const __m256d a = _mm256_loadu_pd(u_ptr); + const __m256d c = _mm256_loadu_pd(v_ptr); + re1 = _mm256_fmadd_pd(a, c, re1); + + const __m256d b = _mm256_loadu_pd(u_ptr + 4); + im2 = _mm256_fmadd_pd(b, c, im2); + + const __m256d d = _mm256_loadu_pd(v_ptr + 4); + re2 = _mm256_fmadd_pd(b, d, re2); + im1 = _mm256_fmadd_pd(a, d, im1); + + u_ptr += 8; + v_ptr += 8; + } + + _mm256_storeu_pd(dst, _mm256_sub_pd(re1, re2)); + _mm256_storeu_pd(dst + 4, _mm256_add_pd(im1, im2)); +} + +EXPORT void reim4_vec_mat2cols_product_avx2(const uint64_t nrows, + double* const dst, // 16 doubles + const double* const u, // nrows * 16 doubles + const double* const v // nrows * 16 doubles +) { + __m256d re1 = _mm256_setzero_pd(); + __m256d im1 = _mm256_setzero_pd(); + __m256d re2 = _mm256_setzero_pd(); + __m256d im2 = 
_mm256_setzero_pd(); + + __m256d ur,ui,ar,ai,br,bi; + for (uint64_t i = 0; i < nrows; ++i) { + ur = _mm256_loadu_pd(u+8*i); + ui = _mm256_loadu_pd(u+8*i+4); + ar = _mm256_loadu_pd(v+16*i); + ai = _mm256_loadu_pd(v+16*i+4); + br = _mm256_loadu_pd(v+16*i+8); + bi = _mm256_loadu_pd(v+16*i+12); + re1 = _mm256_fmsub_pd(ui,ai,re1); + re2 = _mm256_fmsub_pd(ui,bi,re2); + im1 = _mm256_fmadd_pd(ur,ai,im1); + im2 = _mm256_fmadd_pd(ur,bi,im2); + re1 = _mm256_fmsub_pd(ur,ar,re1); + re2 = _mm256_fmsub_pd(ur,br,re2); + im1 = _mm256_fmadd_pd(ui,ar,im1); + im2 = _mm256_fmadd_pd(ui,br,im2); + } + _mm256_storeu_pd(dst, re1); + _mm256_storeu_pd(dst + 4, im1); + _mm256_storeu_pd(dst + 8, re2); + _mm256_storeu_pd(dst + 12, im2); +} diff --git a/spqlios/lib/spqlios/reim4/reim4_arithmetic_ref.c b/spqlios/lib/spqlios/reim4/reim4_arithmetic_ref.c new file mode 100644 index 0000000..3d69f5c --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_arithmetic_ref.c @@ -0,0 +1,214 @@ +#include +#include + +#include "reim4_arithmetic.h" + +void reim4_extract_1blk_from_reim_ref(uint64_t m, uint64_t blk, + double* const dst, // 8 doubles + const double* const src // one reim vector +) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + // copy the real parts + dst[0] = src_ptr[0]; + dst[1] = src_ptr[1]; + dst[2] = src_ptr[2]; + dst[3] = src_ptr[3]; + src_ptr += m; + // copy the imaginary parts + dst[4] = src_ptr[0]; + dst[5] = src_ptr[1]; + dst[6] = src_ptr[2]; + dst[7] = src_ptr[3]; +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_ref(uint64_t m, uint64_t nrows, uint64_t blk, double* const dst, + const double* const src) { + assert(blk < (m >> 2)); + + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows * 2; ++i) { + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += m; + } +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_sl_ref(uint64_t m, uint64_t sl, uint64_t nrows, uint64_t blk, + double* const dst, const double* const src) { + assert(blk < (m >> 2)); + + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + const uint64_t sl_minus_m = sl - m; + for (uint64_t i = 0; i < nrows; ++i) { + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += m; + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += sl_minus_m; + } +} + +// dest(i)=src +// use scalar or sse or avx? 
Code +// should be the inverse of reim4_extract_1col_from_reim +void reim4_save_1blk_to_reim_ref(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + double* dst_ptr = dst + (blk << 2); + // save the real part + dst_ptr[0] = src[0]; + dst_ptr[1] = src[1]; + dst_ptr[2] = src[2]; + dst_ptr[3] = src[3]; + dst_ptr += m; + // save the imag part + dst_ptr[0] = src[4]; + dst_ptr[1] = src[5]; + dst_ptr[2] = src[6]; + dst_ptr[3] = src[7]; +} + +// dest = 0 +void reim4_zero(double* const dst // 8 doubles +) { + for (uint64_t i = 0; i < 8; ++i) dst[i] = 0; +} + +/** @brief dest = a + b */ +void reim4_add(double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] = a + c; + dst[k + 4] = b + d; + } +} + +/** @brief dest = a * b */ +void reim4_mul(double* const dst, // 8 doubles + const double* const u, // 8 doubles + const double* const v // 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] = a * c - b * d; + dst[k + 4] = a * d + b * c; + } +} + +/** @brief dest += a * b */ +void reim4_add_mul(double* const dst, // 8 doubles + const double* const u, // 8 doubles + const double* const v // 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] += a * c - b * d; + dst[k + 4] += a * d + b * c; + } +} + +/** dest = uT * v where u is a vector of size nrows, and v is a nrows x 1 matrix */ +void reim4_vec_mat1col_product_ref(const uint64_t nrows, + double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + reim4_zero(dst); + + for (uint64_t i = 0, j = 0; i < nrows; ++i, j += 8) { + reim4_add_mul(dst, u + j, v + j); + } +} + +/** dest = uT * v where u is a vector of size nrows, and v is a nrows x 2 matrix */ +void reim4_vec_mat2cols_product_ref(const uint64_t nrows, + double* const dst, // 16 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 16 doubles +) { + reim4_zero(dst); + reim4_zero(dst + 8); + + double* dst1 = dst; + double* dst2 = dst + 8; + + for (uint64_t i = 0, j = 0; i < nrows; ++i, j += 8) { + uint64_t double_j = j << 1; + reim4_add_mul(dst1, u + j, v + double_j); + reim4_add_mul(dst2, u + j, v + double_j + 8); + } +} + +/** + * @brief k-th coefficient of the convolution product a * b. + * + * The k-th coefficient is defined as sum_{i+j=k} a[i].b[j]. + * (in this sum, i and j must remain within the bounds [0,sizea[ and [0,sizeb[) + * + * In practice, accounting for these bounds, the convolution function boils down to + * ``` + * res := 0 + * if (k= sizea + sizeb) return; + uint64_t jmin = k >= sizea ? k + 1 - sizea : 0; + uint64_t jmax = k < sizeb ? 
k + 1 : sizeb; + for (uint64_t j = jmin; j < jmax; ++j) { + reim4_add_mul(dest, a + 8 * (k - j), b + 8 * j); + } +} + +/** @brief returns two consecutive convolution coefficients: k and k+1 */ +EXPORT void reim4_convolution_2coeff_ref(uint64_t k, double* dest, const double* a, uint64_t sizea, const double* b, + uint64_t sizeb) { + reim4_convolution_1coeff_ref(k, dest, a, sizea, b, sizeb); + reim4_convolution_1coeff_ref(k + 1, dest + 8, a, sizea, b, sizeb); +} + +/** + * @brief From the convolution a * b, return the coefficients between offset and offset + size + * For the full convolution, use offset=0 and size=sizea+sizeb-1. + */ +EXPORT void reim4_convolution_ref(double* dest, uint64_t dest_size, uint64_t dest_offset, const double* a, + uint64_t sizea, const double* b, uint64_t sizeb) { + for (uint64_t k = 0; k < dest_size; ++k) { + reim4_convolution_1coeff_ref(k + dest_offset, dest + 8 * k, a, sizea, b, sizeb); + } +} diff --git a/spqlios/lib/spqlios/reim4/reim4_execute.c b/spqlios/lib/spqlios/reim4/reim4_execute.c new file mode 100644 index 0000000..9618393 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_execute.c @@ -0,0 +1,19 @@ +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { + tables->function(tables, r, a); +} + +EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { + tables->function(tables, r, a); +} diff --git a/spqlios/lib/spqlios/reim4/reim4_fallbacks_aarch64.c b/spqlios/lib/spqlios/reim4/reim4_fallbacks_aarch64.c new file mode 100644 index 0000000..bfbc4b2 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fallbacks_aarch64.c @@ -0,0 +1,11 @@ +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, + const double* b){UNDEFINED()} + +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b){UNDEFINED()} + +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a){UNDEFINED()} + +EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a){UNDEFINED()} diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_fma.c b/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_fma.c new file mode 100644 index 0000000..833a95b --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_fma.c @@ -0,0 +1,54 @@ +#include +#include +#include + +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, + const double* b_ptr) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + while (r_ptr != rend_ptr) { + __m256d rr = _mm256_loadu_pd(r_ptr); + __m256d ri = _mm256_loadu_pd(r_ptr + 4); + const __m256d ar = _mm256_loadu_pd(a_ptr); + const __m256d ai = _mm256_loadu_pd(a_ptr + 4); + const __m256d br = _mm256_loadu_pd(b_ptr); + const __m256d bi = _mm256_loadu_pd(b_ptr + 4); + + rr = _mm256_fmsub_pd(ai, bi, rr); + rr = _mm256_fmsub_pd(ar, br, rr); + ri = _mm256_fmadd_pd(ar, bi, ri); + ri = 
_mm256_fmadd_pd(ai, br, ri); + + _mm256_storeu_pd(r_ptr, rr); + _mm256_storeu_pd(r_ptr + 4, ri); + + r_ptr += 8; + a_ptr += 8; + b_ptr += 8; + } +} + +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, + const double* b_ptr) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + while (r_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(a_ptr); + const __m256d ai = _mm256_loadu_pd(a_ptr + 4); + const __m256d br = _mm256_loadu_pd(b_ptr); + const __m256d bi = _mm256_loadu_pd(b_ptr + 4); + + const __m256d t1 = _mm256_mul_pd(ai, bi); + const __m256d t2 = _mm256_mul_pd(ar, bi); + + __m256d rr = _mm256_fmsub_pd(ar, br, t1); + __m256d ri = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(r_ptr, rr); + _mm256_storeu_pd(r_ptr + 4, ri); + + r_ptr += 8; + a_ptr += 8; + b_ptr += 8; + } +} diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_ref.c b/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_ref.c new file mode 100644 index 0000000..5254fd6 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_addmul_ref.c @@ -0,0 +1,97 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +void* init_reim4_fftvec_addmul_precomp(REIM4_FFTVEC_ADDMUL_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 2) { + res->function = reim4_fftvec_addmul_fma; + } else { + res->function = reim4_fftvec_addmul_ref; + } + } else { + res->function = reim4_fftvec_addmul_ref; + } + return res; +} + +EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t m) { + REIM4_FFTVEC_ADDMUL_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_fftvec_addmul_precomp(res, m)); +} + +EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + const uint64_t m = precomp->m; + for (uint64_t j = 0; j < m / 4; ++j) { + for (uint64_t i = 0; i < 4; ++i) { + double re = a[i] * b[i] - a[i + 4] * b[i + 4]; + double im = a[i] * b[i + 4] + a[i + 4] * b[i]; + r[i] += re; + r[i + 4] += im; + } + a += 8; + b += 8; + r += 8; + } +} + +EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b) { + static REIM4_FFTVEC_ADDMUL_PRECOMP precomp[32]; + REIM4_FFTVEC_ADDMUL_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_fftvec_addmul_precomp(p, m)) abort(); + } + p->function(p, r, a, b); +} + +void* init_reim4_fftvec_mul_precomp(REIM4_FFTVEC_MUL_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + res->function = reim4_fftvec_mul_fma; + } else { + res->function = reim4_fftvec_mul_ref; + } + } else { + res->function = reim4_fftvec_mul_ref; + } + return res; +} + +EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m) { + REIM4_FFTVEC_MUL_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_fftvec_mul_precomp(res, m)); +} + +EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + for (uint64_t j = 0; j < m / 4; ++j) { + for (uint64_t i = 0; i < 4; ++i) { + double re = a[i] * b[i] - a[i + 4] * b[i + 4]; + double im = a[i] * b[i + 4] + a[i + 4] * b[i]; + r[i] = re; + r[i + 4] = im; + } + a += 8; + b += 8; + 
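+    // advance to the next reim4 block: each operand is 4 re doubles followed by 4 im doubles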
r += 8; + } +} + +EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b) { + static REIM4_FFTVEC_MUL_PRECOMP precomp[32]; + REIM4_FFTVEC_MUL_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_fftvec_mul_precomp(p, m)) abort(); + } + p->function(p, r, a, b); +} diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_fma.c b/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_fma.c new file mode 100644 index 0000000..c175d0e --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_fma.c @@ -0,0 +1,37 @@ +#include +#include +#include + +#include "reim4_fftvec_private.h" + +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r_ptr, const void* a) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + + const double* a_ptr = (double*)a; + while (r_ptr != rend_ptr) { + __m256d t1 = _mm256_loadu_pd(a_ptr); + __m256d t2 = _mm256_loadu_pd(a_ptr + 4); + + _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); + _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); + + r_ptr += 8; + a_ptr += 8; + } +} + +EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a_ptr) { + const double* const aend_ptr = a_ptr + (tables->m << 1); + double* r_ptr = (double*)r; + + while (a_ptr != aend_ptr) { + __m256d t1 = _mm256_loadu_pd(a_ptr); + __m256d t2 = _mm256_loadu_pd(a_ptr + 4); + + _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); + _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); + + r_ptr += 8; + a_ptr += 8; + } +} diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_ref.c b/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_ref.c new file mode 100644 index 0000000..5bee8ff --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_conv_ref.c @@ -0,0 +1,116 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { + const double* x = (double*)a; + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m / 4; ++i) { + double r0 = x[0]; + double i0 = x[1]; + double r1 = x[2]; + double i1 = x[3]; + double r2 = x[4]; + double i2 = x[5]; + double r3 = x[6]; + double i3 = x[7]; + r[0] = r0; + r[1] = r2; + r[2] = r1; + r[3] = r3; + r[4] = i0; + r[5] = i2; + r[6] = i1; + r[7] = i3; + x += 8; + r += 8; + } +} + +void* init_reim4_from_cplx_precomp(REIM4_FROM_CPLX_PRECOMP* res, uint32_t nn) { + res->m = nn / 2; + if (CPU_SUPPORTS("fma")) { + if (nn >= 4) { + res->function = reim4_from_cplx_fma; + } else { + res->function = reim4_from_cplx_ref; + } + } else { + res->function = reim4_from_cplx_ref; + } + return res; +} + +EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m) { + REIM4_FROM_CPLX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_from_cplx_precomp(res, m)); +} + +EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a) { + static REIM4_FROM_CPLX_PRECOMP precomp[32]; + REIM4_FROM_CPLX_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_from_cplx_precomp(p, m)) abort(); + } + p->function(p, r, a); +} + +EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { + double* y = (double*)r; + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m / 4; ++i) { + double r0 = a[0]; + double r2 = a[1]; + double r1 = 
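The conversion kernels in reim4_fftvec_conv_fma.c above rely only on how the AVX unpack instructions interleave the two 128-bit lanes; spelled out for one block of four complex values:

/*
 * cplx layout (interleaved pairs), one block of four complex values:
 *   t1 = { r0, i0, r1, i1 }      t2 = { r2, i2, r3, i3 }
 * _mm256_unpacklo_pd(t1, t2) = { r0, r2, r1, r3 }   (low slot of each 128-bit lane)
 * _mm256_unpackhi_pd(t1, t2) = { i0, i2, i1, i3 }   (high slot of each 128-bit lane)
 * which is exactly the block ordering written by reim4_from_cplx_ref above; applying
 * the same unpack pair to a reim4 block re-interleaves it, which is why
 * reim4_to_cplx_fma uses the identical instruction sequence and agrees with the
 * reim4_to_cplx_ref code being defined here.
 */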
a[2]; + double r3 = a[3]; + double i0 = a[4]; + double i2 = a[5]; + double i1 = a[6]; + double i3 = a[7]; + y[0] = r0; + y[1] = i0; + y[2] = r1; + y[3] = i1; + y[4] = r2; + y[5] = i2; + y[6] = r3; + y[7] = i3; + a += 8; + y += 8; + } +} + +void* init_reim4_to_cplx_precomp(REIM4_TO_CPLX_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 2) { + res->function = reim4_to_cplx_fma; + } else { + res->function = reim4_to_cplx_ref; + } + } else { + res->function = reim4_to_cplx_ref; + } + return res; +} + +EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m) { + REIM4_TO_CPLX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_to_cplx_precomp(res, m)); +} + +EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a) { + static REIM4_TO_CPLX_PRECOMP precomp[32]; + REIM4_TO_CPLX_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_to_cplx_precomp(p, m)) abort(); + } + p->function(p, r, a); +} diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_internal.h b/spqlios/lib/spqlios/reim4/reim4_fftvec_internal.h new file mode 100644 index 0000000..4b076f0 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_internal.h @@ -0,0 +1,20 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_INTERNAL_H +#define SPQLIOS_REIM4_FFTVEC_INTERNAL_H + +#include "reim4_fftvec_public.h" + +EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); + +EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b); +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b); + +EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); + +EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); +EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); + +#endif // SPQLIOS_REIM4_FFTVEC_INTERNAL_H diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_private.h b/spqlios/lib/spqlios/reim4/reim4_fftvec_private.h new file mode 100644 index 0000000..98a5286 --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_private.h @@ -0,0 +1,33 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_PRIVATE_H +#define SPQLIOS_REIM4_FFTVEC_PRIVATE_H + +#include "reim4_fftvec_public.h" + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +typedef void (*R4_FFTVEC_MUL_FUNC)(const REIM4_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); +typedef void (*R4_FFTVEC_ADDMUL_FUNC)(const REIM4_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); +typedef void (*R4_FROM_CPLX_FUNC)(const REIM4_FROM_CPLX_PRECOMP*, double*, const void*); +typedef void (*R4_TO_CPLX_FUNC)(const REIM4_TO_CPLX_PRECOMP*, void*, const double*); + +struct reim4_mul_precomp { + R4_FFTVEC_MUL_FUNC function; + int64_t m; +}; + +struct reim4_addmul_precomp { + R4_FFTVEC_ADDMUL_FUNC function; + int64_t m; +}; + +struct reim4_from_cplx_precomp { + R4_FROM_CPLX_FUNC function; + int64_t m; +}; + +struct reim4_to_cplx_precomp { + R4_TO_CPLX_FUNC function; + int64_t m; +}; + +#endif // 
SPQLIOS_REIM4_FFTVEC_PRIVATE_H diff --git a/spqlios/lib/spqlios/reim4/reim4_fftvec_public.h b/spqlios/lib/spqlios/reim4/reim4_fftvec_public.h new file mode 100644 index 0000000..c833dcf --- /dev/null +++ b/spqlios/lib/spqlios/reim4/reim4_fftvec_public.h @@ -0,0 +1,59 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_PUBLIC_H +#define SPQLIOS_REIM4_FFTVEC_PUBLIC_H + +#include "../commons.h" + +typedef struct reim4_addmul_precomp REIM4_FFTVEC_ADDMUL_PRECOMP; +typedef struct reim4_mul_precomp REIM4_FFTVEC_MUL_PRECOMP; +typedef struct reim4_from_cplx_precomp REIM4_FROM_CPLX_PRECOMP; +typedef struct reim4_to_cplx_precomp REIM4_TO_CPLX_PRECOMP; + +EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m); +EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim4_fftvec_mul_precomp free + +EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t nn); +EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim4_fftvec_addmul_precomp free + +/** + * @brief prepares a conversion from the cplx fftvec layout to the reim4 layout. + * @param m complex dimension m from C[X] mod X^m-i. + */ +EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m); +EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); +#define delete_reim4_from_cplx_precomp free + +/** + * @brief prepares a conversion from the reim4 fftvec layout to the cplx layout + * @param m the complex dimension m from C[X] mod X^m-i. + */ +EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m); +EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); +#define delete_reim4_to_cplx_precomp free + +/** + * @brief Simpler API for the fftvec multiplication function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b); + +/** + * @brief Simpler API for the fftvec addmul function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b); + +/** + * @brief Simpler API for cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a); + +/** + * @brief Simpler API for to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
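A minimal usage sketch of the two *_simple entry points above. It assumes m is a power of two (the cached tables are indexed by log2(m)), that each array holds 2*m doubles in the reim4 layout, and that the public header is reachable by its in-tree path; the helper name is hypothetical.

#include <stdint.h>
#include "spqlios/reim4/reim4_fftvec_public.h"

static void demo_reim4_simple(void) {
  enum { M = 8 };                       /* 8 complex coefficients -> 16 doubles per vector */
  double a[2 * M], b[2 * M], r[2 * M];
  for (uint32_t i = 0; i < 2 * M; ++i) {
    a[i] = (double)(i + 1);
    b[i] = ((i & 4) == 0) ? 1.0 : 0.0;  /* real slots = 1, imaginary slots = 0, i.e. b = 1+0i */
  }
  reim4_fftvec_mul_simple(M, r, a, b);     /* r  = a * b element-wise (complex), so r = a */
  reim4_fftvec_addmul_simple(M, r, a, b);  /* r += a * b, so r = 2*a                      */
  /* the first call per dimension builds a static table: do such a dry run before multithreading */
}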
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a); + +#endif // SPQLIOS_REIM4_FFTVEC_PUBLIC_H \ No newline at end of file diff --git a/spqlios/lib/test/CMakeLists.txt b/spqlios/lib/test/CMakeLists.txt new file mode 100644 index 0000000..132ba71 --- /dev/null +++ b/spqlios/lib/test/CMakeLists.txt @@ -0,0 +1,142 @@ +set(CMAKE_CXX_STANDARD 17) + +set(test_incs ..) +set(gtest_libs) +set(benchmark_libs) +# searching for libgtest +find_path(gtest_inc NAMES gtest/gtest.h) +find_library(gtest NAMES gtest) +find_library(gtest_main REQUIRED NAMES gtest_main) +if (gtest_inc AND gtest AND gtest_main) + message(STATUS "Found gtest: I=${gtest_inc} L=${gtest},${gtest_main}") + set(test_incs ${test_incs} ${gtest_inc}) + set(gtest_libs ${gtest_libs} ${gtest} ${gtest_main} pthread) +else() + message(FATAL_ERROR "Libgtest not found (required if ENABLE_TESTING is on): I=${gtest_inc} L=${gtest},${gtest_main}") +endif() +# searching for libbenchmark +find_path(benchmark_inc NAMES benchmark/benchmark.h) +find_library(benchmark NAMES benchmark) +if (benchmark_inc AND benchmark) + message(STATUS "Found benchmark: I=${benchmark_inc} L=${benchmark}") + set(test_incs ${test_incs} ${benchmark_inc}) + set(benchmark_libs ${benchmark_libs} ${benchmark}) +else() + message(FATAL_ERROR "Libbenchmark not found (required if ENABLE_TESTING is on): I=${benchmark_inc} L=${benchmark}") +endif() +find_path(VALGRIND_DIR NAMES valgrind/valgrind.h) +if (VALGRIND_DIR) + message(STATUS "Found valgrind header ${VALGRIND_DIR}") +else () + # for now, we will fail if we don't find valgrind for tests + message(STATUS "CANNOT FIND valgrind header: ${VALGRIND_DIR}") +endif () + +add_library(spqlios-testlib SHARED + testlib/random.cpp + testlib/test_commons.h + testlib/test_commons.cpp + testlib/mod_q120.h + testlib/mod_q120.cpp + testlib/negacyclic_polynomial.cpp + testlib/negacyclic_polynomial.h + testlib/negacyclic_polynomial_impl.h + testlib/reim4_elem.cpp + testlib/reim4_elem.h + testlib/fft64_dft.cpp + testlib/fft64_dft.h + testlib/fft64_layouts.h + testlib/fft64_layouts.cpp + testlib/ntt120_layouts.cpp + testlib/ntt120_layouts.h + testlib/ntt120_dft.cpp + testlib/ntt120_dft.h + testlib/test_hash.cpp + testlib/sha3.h + testlib/sha3.c + testlib/polynomial_vector.h + testlib/polynomial_vector.cpp + testlib/vec_rnx_layout.h + testlib/vec_rnx_layout.cpp + testlib/zn_layouts.h + testlib/zn_layouts.cpp +) +if (VALGRIND_DIR) + target_include_directories(spqlios-testlib PRIVATE ${VALGRIND_DIR}) + target_compile_definitions(spqlios-testlib PRIVATE VALGRIND_MEM_TESTS) +endif () +target_link_libraries(spqlios-testlib libspqlios) + + + +# main unittest file +message(STATUS "${gtest_libs}") +set(UNITTEST_FILES + spqlios_test.cpp + spqlios_reim_conversions_test.cpp + spqlios_reim_test.cpp + spqlios_reim4_arithmetic_test.cpp + spqlios_cplx_test.cpp + spqlios_cplx_conversions_test.cpp + spqlios_q120_ntt_test.cpp + spqlios_q120_arithmetic_test.cpp + spqlios_coeffs_arithmetic_test.cpp + spqlios_vec_znx_big_test.cpp + spqlios_znx_small_test.cpp + spqlios_vmp_product_test.cpp + spqlios_vec_znx_dft_test.cpp + spqlios_svp_test.cpp + spqlios_svp_product_test.cpp + spqlios_vec_znx_test.cpp + spqlios_vec_rnx_test.cpp + spqlios_vec_rnx_vmp_test.cpp + spqlios_vec_rnx_conversions_test.cpp + spqlios_vec_rnx_ppol_test.cpp + spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp + spqlios_zn_approxdecomp_test.cpp + spqlios_zn_conversions_test.cpp + 
spqlios_zn_vmp_test.cpp + + +) + +add_executable(spqlios-test ${UNITTEST_FILES}) +target_link_libraries(spqlios-test spqlios-testlib libspqlios ${gtest_libs}) +target_include_directories(spqlios-test PRIVATE ${test_incs}) +add_test(NAME spqlios-test COMMAND spqlios-test) +if (WIN32) + # copy the dlls to the test directory + cmake_minimum_required(VERSION 3.26) + add_custom_command( + POST_BUILD + TARGET spqlios-test + COMMAND ${CMAKE_COMMAND} -E copy + -t $ $ $ + COMMAND_EXPAND_LISTS + ) +endif() + +# benchmarks +add_executable(spqlios-cplx-fft-bench spqlios_cplx_fft_bench.cpp) +target_link_libraries(spqlios-cplx-fft-bench libspqlios ${benchmark_libs} pthread) +target_include_directories(spqlios-cplx-fft-bench PRIVATE ${test_incs}) + +if (X86 OR X86_WIN32) + add_executable(spqlios-q120-ntt-bench spqlios_q120_ntt_bench.cpp) + target_link_libraries(spqlios-q120-ntt-bench libspqlios ${benchmark_libs} pthread) + target_include_directories(spqlios-q120-ntt-bench PRIVATE ${test_incs}) + + add_executable(spqlios-q120-arithmetic-bench spqlios_q120_arithmetic_bench.cpp) + target_link_libraries(spqlios-q120-arithmetic-bench libspqlios ${benchmark_libs} pthread) + target_include_directories(spqlios-q120-arithmetic-bench PRIVATE ${test_incs}) +endif () + +if (X86 OR X86_WIN32) + add_executable(spqlios_reim4_arithmetic_bench spqlios_reim4_arithmetic_bench.cpp) + target_link_libraries(spqlios_reim4_arithmetic_bench ${benchmark_libs} libspqlios pthread) + target_include_directories(spqlios_reim4_arithmetic_bench PRIVATE ${test_incs}) +endif () + +if (DEVMODE_INSTALL) + install(TARGETS spqlios-testlib) +endif() diff --git a/spqlios/lib/test/spqlios_coeffs_arithmetic_test.cpp b/spqlios/lib/test/spqlios_coeffs_arithmetic_test.cpp new file mode 100644 index 0000000..dbdfd11 --- /dev/null +++ b/spqlios/lib/test/spqlios_coeffs_arithmetic_test.cpp @@ -0,0 +1,488 @@ +#include +#include + +#include +#include +#include +#include + +#include "../spqlios/coeffs/coeffs_arithmetic.h" +#include "test/testlib/mod_q120.h" +#include "testlib/negacyclic_polynomial.h" +#include "testlib/test_commons.h" + +/// tests of element-wise operations +template +void test_elemw_op(F elemw_op, G poly_elemw_op) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial a = polynomial::random(n); + polynomial b = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + // out of place + expect = poly_elemw_op(a, b); + elemw_op(n, actual.data(), a.data(), b.data()); + ASSERT_EQ(actual, expect); + // in place 1 + actual = polynomial::random(n); + expect = poly_elemw_op(actual, b); + elemw_op(n, actual.data(), actual.data(), b.data()); + ASSERT_EQ(actual, expect); + // in place 2 + actual = polynomial::random(n); + expect = poly_elemw_op(a, actual); + elemw_op(n, actual.data(), a.data(), actual.data()); + ASSERT_EQ(actual, expect); + // in place 3 + actual = polynomial::random(n); + expect = poly_elemw_op(actual, actual); + elemw_op(n, actual.data(), actual.data(), actual.data()); + ASSERT_EQ(actual, expect); + } +} + +static polynomial poly_i64_add(const polynomial& u, polynomial& v) { return u + v; } +static polynomial poly_i64_sub(const polynomial& u, polynomial& v) { return u - v; } +TEST(coeffs_arithmetic, znx_add_i64_ref) { test_elemw_op(znx_add_i64_ref, poly_i64_add); } +TEST(coeffs_arithmetic, znx_sub_i64_ref) { test_elemw_op(znx_sub_i64_ref, poly_i64_sub); } +#ifdef __x86_64__ +TEST(coeffs_arithmetic, znx_add_i64_avx) { test_elemw_op(znx_add_i64_avx, poly_i64_add); } +TEST(coeffs_arithmetic, 
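For reference, test_elemw_op above pins down plain coefficient-wise semantics (checked against the testlib's polynomial operators) together with every aliasing pattern r==a, r==b and r==a==b; for the addition entry points the expected behaviour reduces to the following scalar model (hypothetical name):

#include <stdint.h>

/* behaviour checked for znx_add_i64_ref / _avx, for any n and any aliasing of r with a or b */
static void znx_add_i64_model(uint64_t n, int64_t* r, const int64_t* a, const int64_t* b) {
  for (uint64_t i = 0; i < n; ++i) r[i] = a[i] + b[i];  /* coefficient-wise addition */
}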
znx_sub_i64_avx) { test_elemw_op(znx_sub_i64_avx, poly_i64_sub); } +#endif + +/// tests of element-wise operations +template +void test_elemw_unary_op(F elemw_op, G poly_elemw_op) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial a = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + // out of place + expect = poly_elemw_op(a); + elemw_op(n, actual.data(), a.data()); + ASSERT_EQ(actual, expect); + // in place + actual = polynomial::random(n); + expect = poly_elemw_op(actual); + elemw_op(n, actual.data(), actual.data()); + ASSERT_EQ(actual, expect); + } +} + +static polynomial poly_i64_neg(const polynomial& u) { return -u; } +static polynomial poly_i64_copy(const polynomial& u) { return u; } +TEST(coeffs_arithmetic, znx_neg_i64_ref) { test_elemw_unary_op(znx_negate_i64_ref, poly_i64_neg); } +TEST(coeffs_arithmetic, znx_copy_i64_ref) { test_elemw_unary_op(znx_copy_i64_ref, poly_i64_copy); } +#ifdef __x86_64__ +TEST(coeffs_arithmetic, znx_neg_i64_avx) { test_elemw_unary_op(znx_negate_i64_avx, poly_i64_neg); } +#endif + +/// tests of the rotations out of place +template +void test_rotation_outplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p)); + } + // rotate using the function + rotate(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_rotate_f64) { test_rotation_outplace(rnx_rotate_f64); } +TEST(coeffs_arithmetic, znx_rotate_i64) { test_rotation_outplace(znx_rotate_i64); } + +/// tests of the rotations out of place +template +void test_rotation_inplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p)); + } + // rotate using the function + rotate(n, p, actual.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_rotate_inplace_f64) { test_rotation_inplace(rnx_rotate_inplace_f64); } + +TEST(coeffs_arithmetic, znx_rotate_inplace_i64) { test_rotation_inplace(znx_rotate_inplace_i64); } + +/// tests of the rotations out of place +template +void test_mul_xp_minus_one_outplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p) - poly.get_coeff(i)); + } + // rotate using the function + rotate(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_mul_xp_minus_one_f64) { test_mul_xp_minus_one_outplace(rnx_mul_xp_minus_one); } +TEST(coeffs_arithmetic, znx_mul_xp_minus_one_i64) { test_mul_xp_minus_one_outplace(znx_mul_xp_minus_one); } + +/// tests of the rotations out of place +template +void test_mul_xp_minus_one_inplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for 
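The rotation and (X^p - 1) tests above all use the same negacyclic indexing: in Z[X]/(X^n+1), multiplying by X^p shifts each exponent by p and flips the sign every time an exponent wraps past n (the testlib's get_coeff is assumed to implement exactly this rule). A scalar model with a hypothetical name:

#include <stdint.h>

/* out-of-place model of res = X^p * in  in  Z[X]/(X^n + 1); res must not alias in */
static void znx_rotate_model(uint64_t n, int64_t p, int64_t* res, const int64_t* in) {
  const int64_t two_n = 2 * (int64_t)n;
  for (int64_t i = 0; i < (int64_t)n; ++i) {
    int64_t j = ((i - p) % two_n + two_n) % two_n;            /* source exponent in [0, 2n) */
    res[i] = (j < (int64_t)n) ? in[j] : -in[j - (int64_t)n];  /* X^(n+k) = -X^k */
  }
}
/* rnx/znx_mul_xp_minus_one is then expected to return (X^p - 1)*in, i.e. this rotation minus in */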
(uint64_t trial = 0; trial < 10; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p) - poly.get_coeff(i)); + } + // rotate using the function + rotate(n, p, actual.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_mul_xp_minus_one_inplace_f64) { + test_mul_xp_minus_one_inplace(rnx_mul_xp_minus_one_inplace); +} + +// TEST(coeffs_arithmetic, znx_mul_xp_minus_one_inplace_i64) { +// test_mul_xp_minus_one_inplace(znx_rotate_inplace_i64); } +/// tests of the automorphisms out of place +template +void test_automorphism_outplace(F automorphism) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32) | int64_t(1); // make it odd + // automorphism p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i * p, poly.get_coeff(i)); + } + // rotate using the function + automorphism(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_automorphism_f64) { test_automorphism_outplace(rnx_automorphism_f64); } +TEST(coeffs_arithmetic, znx_automorphism_i64) { test_automorphism_outplace(znx_automorphism_i64); } + +/// tests of the automorphisms out of place +template +void test_automorphism_inplace(F automorphism) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096, 16384}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for (uint64_t trial = 0; trial < 20; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32) | int64_t(1); // make it odd + // automorphism p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i * p, poly.get_coeff(i)); + } + automorphism(n, p, actual.data()); + if (!(actual == expect)) { + std::cerr << "automorphism p: " << p << std::endl; + for (uint64_t i = 0; i < n; ++i) { + std::cerr << i << " " << actual.get_coeff(i) << " vs " << expect.get_coeff(i) << " " + << (actual.get_coeff(i) == expect.get_coeff(i)) << std::endl; + } + } + ASSERT_EQ(actual, expect); + } + } +} +TEST(coeffs_arithmetic, rnx_automorphism_inplace_f64) { + test_automorphism_inplace(rnx_automorphism_inplace_f64); +} +TEST(coeffs_arithmetic, znx_automorphism_inplace_i64) { + test_automorphism_inplace(znx_automorphism_inplace_i64); +} + +// TODO: write a test later! 
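The automorphism tests map a(X) to a(X^p) for odd p: the exponent of each monomial is multiplied by p and reduced with the same negacyclic sign rule. A scalar model consistent with the expected values built above (hypothetical name; p odd and n a power of two make the map a permutation of the coefficients):

#include <stdint.h>

/* out-of-place model of res(X) = in(X^p) in Z[X]/(X^n + 1), p odd */
static void znx_automorphism_model(uint64_t n, int64_t p, int64_t* res, const int64_t* in) {
  const int64_t two_n = 2 * (int64_t)n;
  for (uint64_t i = 0; i < n; ++i) {
    int64_t e = (int64_t)(((__int128_t)i * p) % two_n);  /* exponent of X^(i*p), signed remainder */
    if (e < 0) e += two_n;
    if (e < (int64_t)n) res[e] = in[i];
    else res[e - (int64_t)n] = -in[i];                   /* X^(n+k) = -X^k */
  }
}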
+ +/** + * @brief res = (X^p-1).in + * @param nn the ring dimension + * @param p must be between -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); + +// normalize with no carry in nor carry out +template +void test_znx_normalize(F normalize) { + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 2) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial out; + int64_t* inp_ptr; + if (inplace_flag == 1) { + out = polynomial(inp); + inp_ptr = out.data(); + } else { + out = polynomial(n); + inp_ptr = inp.data(); + } + + znx_normalize(n, base_k, out.data(), nullptr, inp_ptr, nullptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t y = out.get_coeff(i); + const int64_t y_exp = centermod(x, INT64_C(1) << base_k); + ASSERT_EQ(y, y_exp) << n << " " << base_k << " " << i << " " << x << " " << y; + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_outplace) { test_znx_normalize<0>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_inplace) { test_znx_normalize<1>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cout(F normalize) { + static_assert(inplace_flag < 3, "either out or cout can be inplace with inp"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 2) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial out, cout; + int64_t* inp_ptr; + if (inplace_flag == 1) { + // out and inp are the same + out = polynomial(inp); + inp_ptr = out.data(); + cout = polynomial(n); + } else if (inplace_flag == 2) { + // carry out and inp are the same + cout = polynomial(inp); + inp_ptr = cout.data(); + out = polynomial(n); + } else { + // inp, out and carry out are distinct + out = polynomial(n); + cout = polynomial(n); + inp_ptr = inp.data(); + } + + znx_normalize(n, base_k, has_output ? 
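All of the normalize tests compare the low digit against the testlib's centermod, i.e. the representative of x modulo 2^k lying in [-2^(k-1), 2^(k-1)) (the behaviour on the exact boundary 2^(k-1) is assumed). A self-contained scalar sketch with a hypothetical name, valid for the base_k values 2..62 used here:

#include <stdint.h>

static int64_t centermod_pow2(int64_t x, uint64_t k) {
  const int64_t q = (int64_t)1 << k;   /* modulus 2^k, 2 <= k <= 62 */
  int64_t y = ((x % q) + q) % q;       /* representative in [0, 2^k) */
  return (y >= q / 2) ? y - q : y;     /* recenter into [-2^(k-1), 2^(k-1)) */
}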
out.data() : nullptr, cout.data(), inp_ptr, nullptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t co = cout.get_coeff(i); + const int64_t y_exp = centermod((int64_t)x, INT64_C(1) << base_k); + const int64_t co_exp = (x - y_exp) >> base_k; + ASSERT_EQ(co, co_exp); + + if (has_output) { + const int64_t y = out.get_coeff(i); + ASSERT_EQ(y, y_exp); + } + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cout_outplace) { test_znx_normalize_cout<0, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_outplace) { test_znx_normalize_cout<0, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cout_inplace1) { test_znx_normalize_cout<1, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_inplace1) { test_znx_normalize_cout<1, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cout_inplace2) { test_znx_normalize_cout<2, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_inplace2) { test_znx_normalize_cout<2, true>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cin(F normalize) { + static_assert(inplace_flag < 3, "either inp or cin can be inplace with out"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, -(INT64_C(1) << 62)); + inp.set_coeff(2, (INT64_C(1) << 62)); + inp.set_coeff(3, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial cin = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + + polynomial out; + int64_t *inp_ptr, *cin_ptr; + if (inplace_flag == 1) { + // out and inp are the same + out = polynomial(inp); + inp_ptr = out.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 2) { + // out and carry in are the same + out = polynomial(cin); + inp_ptr = inp.data(); + cin_ptr = out.data(); + } else { + // inp, carry in and out are distinct + out = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = cin.data(); + } + + znx_normalize(n, base_k, out.data(), nullptr, inp_ptr, cin_ptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t ci = cin.get_coeff(i); + const int64_t y = out.get_coeff(i); + + const __int128_t xp = (__int128_t)x + ci; + const int64_t y_exp = centermod((int64_t)xp, INT64_C(1) << base_k); + + ASSERT_EQ(y, y_exp) << n << " " << base_k << " " << i << " " << x << " " << y << " " << ci; + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cin_outplace) { test_znx_normalize_cin<0>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_inplace1) { test_znx_normalize_cin<1>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_inplace2) { test_znx_normalize_cin<2>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cin_cout(F normalize) { + static_assert(inplace_flag < 7, "either inp or cin can be inplace with out"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, -(INT64_C(1) << 62)); + inp.set_coeff(2, (INT64_C(1) << 62)); + inp.set_coeff(3, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 
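Taken together, the carry-in/carry-out variants check one invariant per coefficient: x + carry_in = y + 2^k * carry_out, where y is the centered digit. A scalar model of a single normalization step (hypothetical helper; the 128-bit sum mirrors the __int128_t arithmetic used in the assertions):

#include <stdint.h>

static void znx_normalize_step_model(uint64_t k, int64_t x, int64_t cin,
                                     int64_t* y, int64_t* cout) {
  const __int128_t s = (__int128_t)x + cin;  /* exact sum, may not fit in 64 bits */
  const __int128_t q = (__int128_t)1 << k;
  __int128_t d = ((s % q) + q) % q;          /* digit in [0, 2^k) */
  if (d >= q / 2) d -= q;                    /* recenter into [-2^(k-1), 2^(k-1)) */
  *y = (int64_t)d;
  *cout = (int64_t)((s - d) / q);            /* exact: s - d is a multiple of 2^k */
}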
19, 35, 62}) { + polynomial cin = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + + polynomial out, cout; + int64_t *inp_ptr, *cin_ptr; + if (inplace_flag == 1) { + // out == inp + out = polynomial(inp); + cout = polynomial(n); + inp_ptr = out.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 2) { + // cout == inp + out = polynomial(n); + cout = polynomial(inp); + inp_ptr = cout.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 3) { + // out == cin + out = polynomial(cin); + cout = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = out.data(); + } else if (inplace_flag == 4) { + // cout == cin + out = polynomial(n); + cout = polynomial(cin); + inp_ptr = inp.data(); + cin_ptr = cout.data(); + } else if (inplace_flag == 5) { + // out == inp, cout == cin + out = polynomial(inp); + cout = polynomial(cin); + inp_ptr = out.data(); + cin_ptr = cout.data(); + } else if (inplace_flag == 6) { + // out == cin, cout == inp + out = polynomial(cin); + cout = polynomial(inp); + inp_ptr = cout.data(); + cin_ptr = out.data(); + } else { + out = polynomial(n); + cout = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = cin.data(); + } + + znx_normalize(n, base_k, has_output ? out.data() : nullptr, cout.data(), inp_ptr, cin_ptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t ci = cin.get_coeff(i); + const int64_t co = cout.get_coeff(i); + + const __int128_t xp = (__int128_t)x + ci; + const int64_t y_exp = centermod((int64_t)xp, INT64_C(1) << base_k); + const int64_t co_exp = (xp - y_exp) >> base_k; + ASSERT_EQ(co, co_exp); + + if (has_output) { + const int64_t y = out.get_coeff(i); + ASSERT_EQ(y, y_exp); + } + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cin_cout_outplace) { test_znx_normalize_cin_cout<0, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_outplace) { test_znx_normalize_cin_cout<0, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace1) { test_znx_normalize_cin_cout<1, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace1) { test_znx_normalize_cin_cout<1, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace2) { test_znx_normalize_cin_cout<2, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace2) { test_znx_normalize_cin_cout<2, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace3) { test_znx_normalize_cin_cout<3, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace3) { test_znx_normalize_cin_cout<3, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace4) { test_znx_normalize_cin_cout<4, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace4) { test_znx_normalize_cin_cout<4, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace5) { test_znx_normalize_cin_cout<5, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace5) { test_znx_normalize_cin_cout<5, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace6) { test_znx_normalize_cin_cout<6, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace6) { test_znx_normalize_cin_cout<6, true>(znx_normalize); } diff --git 
a/spqlios/lib/test/spqlios_cplx_conversions_test.cpp b/spqlios/lib/test/spqlios_cplx_conversions_test.cpp new file mode 100644 index 0000000..32c9158 --- /dev/null +++ b/spqlios/lib/test/spqlios_cplx_conversions_test.cpp @@ -0,0 +1,86 @@ +#include + +#include + +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/cplx/cplx_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, cplx_from_znx32_ref_vs_fma) { + const uint32_t m = 128; + int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + CPLX* dst1 = (CPLX*)(src + 2 * m); + CPLX* dst2 = (CPLX*)(src + 6 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i] = rand() - RAND_MAX / 2; + } + CPLX_FROM_ZNX32_PRECOMP precomp; + precomp.m = m; + cplx_from_znx32_ref(&precomp, dst1, src); + // cplx_from_znx32_simple(m, 32, dst1, src); + cplx_from_znx32_avx2_fma(&precomp, dst2, src); + for (uint64_t i = 0; i < m; ++i) { + ASSERT_EQ(dst1[i][0], dst2[i][0]); + ASSERT_EQ(dst1[i][1], dst2[i][1]); + } + spqlios_free(src); +} +#endif + +#ifdef __x86_64__ +TEST(fft, cplx_from_tnx32_ref_vs_fma) { + const uint32_t m = 128; + int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + CPLX* dst1 = (CPLX*)(src + 2 * m); + CPLX* dst2 = (CPLX*)(src + 6 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i] = rand() + (rand() << 20); + } + CPLX_FROM_TNX32_PRECOMP precomp; + precomp.m = m; + cplx_from_tnx32_ref(&precomp, dst1, src); + // cplx_from_tnx32_simple(m, dst1, src); + cplx_from_tnx32_avx2_fma(&precomp, dst2, src); + for (uint64_t i = 0; i < m; ++i) { + ASSERT_EQ(dst1[i][0], dst2[i][0]); + ASSERT_EQ(dst1[i][1], dst2[i][1]); + } + spqlios_free(src); +} +#endif + +#ifdef __x86_64__ +TEST(fft, cplx_to_tnx32_ref_vs_fma) { + for (const uint32_t m : {8, 128, 1024, 65536}) { + for (const double divisor : {double(1), double(m), double(0.5)}) { + CPLX* src = (CPLX*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + int32_t* dst1 = (int32_t*)(src + m); + int32_t* dst2 = (int32_t*)(src + 2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i][0] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; + src[i][1] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; + } + CPLX_TO_TNX32_PRECOMP precomp; + precomp.m = m; + precomp.divisor = divisor; + cplx_to_tnx32_ref(&precomp, dst1, src); + cplx_to_tnx32_avx2_fma(&precomp, dst2, src); + // cplx_to_tnx32_simple(m, divisor, 18, dst2, src); + for (uint64_t i = 0; i < 2 * m; ++i) { + double truevalue = + (src[i % m][i / m] / divisor - floor(src[i % m][i / m] / divisor + 0.5)) * (INT64_C(1) << 32); + if (fabs(truevalue - floor(truevalue)) == 0.5) { + // ties can differ by 0, 1 or -1 + ASSERT_LE(abs(dst1[i] - dst2[i]), 0) + << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; + } else { + // otherwise, we should have equality + ASSERT_LE(abs(dst1[i] - dst2[i]), 0) + << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; + } + } + spqlios_free(src); + } + } +} +#endif diff --git a/spqlios/lib/test/spqlios_cplx_fft_bench.cpp b/spqlios/lib/test/spqlios_cplx_fft_bench.cpp new file mode 100644 index 0000000..19f4905 --- /dev/null +++ b/spqlios/lib/test/spqlios_cplx_fft_bench.cpp @@ -0,0 +1,112 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/reim/reim_fft.h" + +using namespace std; + +void init_random_values(uint64_t n, double* v) { + for (uint64_t i = 
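The to_tnx32 comparison above treats each output as a 32-bit torus element: the input is divided by divisor, reduced to its centered fractional part, and scaled by 2^32. A scalar sketch of that reference mapping (hypothetical name; the exact-tie case, which the test explicitly tolerates, is ignored, and the final wrap to 32 bits relies on the usual two's-complement narrowing):

#include <math.h>
#include <stdint.h>

static int32_t to_tnx32_model(double x, double divisor) {
  const double u = x / divisor;
  const double f = u - floor(u + 0.5);                /* centered fractional part in [-1/2, 1/2] */
  return (int32_t)(int64_t)llrint(f * 4294967296.0);  /* scale by 2^32 and wrap to 32 bits */
}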
0; i < n; ++i) v[i] = rand() - (RAND_MAX >> 1); +} + +void benchmark_cplx_fft(benchmark::State& state) { + const int32_t nn = state.range(0); + CPLX_FFT_PRECOMP* a = new_cplx_fft_precomp(nn / 2, 1); + double* c = (double*)cplx_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + cplx_fft(a, c); + } + delete_cplx_fft_precomp(a); +} + +void benchmark_cplx_ifft(benchmark::State& state) { + const int32_t nn = state.range(0); + CPLX_IFFT_PRECOMP* a = new_cplx_ifft_precomp(nn / 2, 1); + double* c = (double*)cplx_ifft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_ifft_simple(nn/2, c); + cplx_ifft(a, c); + } + delete_cplx_ifft_precomp(a); +} + +void benchmark_reim_fft(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_FFT_PRECOMP* a = new_reim_fft_precomp(m, 1); + double* c = reim_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + reim_fft(a, c); + } + delete_reim_fft_precomp(a); +} + +#ifdef __aarch64__ +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers); +EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d); + +void benchmark_reim_fft_neon(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_FFT_PRECOMP* a = new_reim_fft_precomp_neon(m, 1); + double* c = reim_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + reim_fft_neon(a, c); + } + delete_reim_fft_precomp(a); +} +#endif + +void benchmark_reim_ifft(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_IFFT_PRECOMP* a = new_reim_ifft_precomp(m, 1); + double* c = reim_ifft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_ifft_simple(nn/2, c); + reim_ifft(a, c); + } + delete_reim_ifft_precomp(a); +} + +// #define ARGS Arg(1024)->Arg(8192)->Arg(32768)->Arg(65536) +#define ARGS Arg(64)->Arg(256)->Arg(1024)->Arg(2048)->Arg(4096)->Arg(8192)->Arg(16384)->Arg(32768)->Arg(65536) + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; + std::cout << "Dimensions n in the benchmark below are in \"real FFT\" modulo X^n+1" << std::endl; + std::cout << "The complex dimension m (modulo X^m-i) is half of it" << std::endl; + BENCHMARK(benchmark_cplx_fft)->ARGS; + BENCHMARK(benchmark_cplx_ifft)->ARGS; + BENCHMARK(benchmark_reim_fft)->ARGS; +#ifdef __aarch64__ + BENCHMARK(benchmark_reim_fft_neon)->ARGS; +#endif + BENCHMARK(benchmark_reim_ifft)->ARGS; + // if (CPU_SUPPORTS("avx512f")) { + // BENCHMARK(bench_cplx_fftvec_twiddle_avx512)->ARGS; + // BENCHMARK(bench_cplx_fftvec_bitwiddle_avx512)->ARGS; + //} + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/spqlios/lib/test/spqlios_cplx_test.cpp b/spqlios/lib/test/spqlios_cplx_test.cpp new file mode 100644 index 0000000..13c5513 --- /dev/null +++ b/spqlios/lib/test/spqlios_cplx_test.cpp @@ -0,0 +1,496 @@ +#include + +#include "gtest/gtest.h" +#include "spqlios/commons_private.h" +#include "spqlios/cplx/cplx_fft.h" +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/cplx/cplx_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, ifft16_fma_vs_ref) { + CPLX data[16]; + CPLX omega[8]; + for (uint64_t i = 0; i < 32; 
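The n-versus-m note printed by the benchmark's main() reflects a layout convention that is assumed here rather than restated by the code:

/*
 * Assumed convention: a real polynomial a(X) mod X^n + 1, with n = 2m, is identified
 * with the m complex coefficients c_j = a_j + i*a_{j+m} (j < m), and the cplx/reim
 * transforms evaluate that complex vector modulo X^m - i.  Under this view the 2m
 * doubles filled by init_random_values form one complete input: m real parts followed
 * by m imaginary parts in the reim layout, or m interleaved (re, im) pairs in the
 * CPLX layout.
 */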
++i) ((double*)data)[i] = 2 * i + 1; //(rand()%100)-50; + for (uint64_t i = 0; i < 16; ++i) ((double*)omega)[i] = i + 1; //(rand()%100)-50; + CPLX copydata[16]; + CPLX copyomega[8]; + memcpy(copydata, data, sizeof(copydata)); + memcpy(copyomega, omega, sizeof(copyomega)); + cplx_ifft16_avx_fma(data, omega); + cplx_ifft16_ref(copydata, copyomega); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - copydata[i][0]); + double d2 = fabs(data[i][0] - copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + /* + printf("data:\n"); + for (uint64_t i=0; i<4; ++i) { + for (uint64_t j=0; j<8; ++j) { + printf("%.5lf ", data[4 * i + j / 2][j % 2]); + } + printf("\n"); + } + printf("copydata:\n"); + for (uint64_t i=0; i<4; ++i) { + for (uint64_t j=0; j<8; ++j) { + printf("%5.5lf ", copydata[4 * i + j / 2][j % 2]); + } + printf("\n"); + } + */ + ASSERT_EQ(distance, 0); +} + +#endif + +void cplx_zero(CPLX r) { r[0] = r[1] = 0; } +void cplx_addmul(CPLX r, const CPLX a, const CPLX b) { + double re = r[0] + a[0] * b[0] - a[1] * b[1]; + double im = r[1] + a[0] * b[1] + a[1] * b[0]; + r[0] = re; + r[1] = im; +} + +void halfcfft_eval(CPLX res, uint32_t nn, uint32_t k, const CPLX* coeffs, const CPLX* powomegas) { + const uint32_t N = nn / 2; + cplx_zero(res); + for (uint64_t i = 0; i < N; ++i) { + cplx_addmul(res, coeffs[i], powomegas[(k * i) % (2 * nn)]); + } +} +void halfcfft_naive(uint32_t nn, CPLX* data) { + const uint32_t N = nn / 2; + CPLX* in = (CPLX*)malloc(N * sizeof(CPLX)); + CPLX* powomega = (CPLX*)malloc(2 * nn * sizeof(CPLX)); + for (uint64_t i = 0; i < (2 * nn); ++i) { + powomega[i][0] = m_accurate_cos((M_PI * i) / nn); + powomega[i][1] = m_accurate_sin((M_PI * i) / nn); + } + memcpy(in, data, N * sizeof(CPLX)); + for (uint64_t j = 0; j < N; ++j) { + uint64_t p = rint(log2(N)) + 2; + uint64_t k = revbits(p, j) + 1; + halfcfft_eval(data[j], nn, k, in, powomega); + } + free(powomega); + free(in); +} + +#ifdef __x86_64__ +TEST(fft, fft16_fma_vs_ref) { + CPLX data[16]; + CPLX omega[8]; + for (uint64_t i = 0; i < 32; ++i) ((double*)data)[i] = rand() % 1000; + for (uint64_t i = 0; i < 16; ++i) ((double*)omega)[i] = rand() % 1000; + CPLX copydata[16]; + CPLX copyomega[8]; + memcpy(copydata, data, sizeof(copydata)); + memcpy(copyomega, omega, sizeof(copyomega)); + cplx_fft16_avx_fma(data, omega); + cplx_fft16_ref(copydata, omega); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - copydata[i][0]); + double d2 = fabs(data[i][0] - copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_EQ(distance, 0); +} +#endif + +TEST(fft, citwiddle_then_invcitwiddle) { + CPLX om; + CPLX ombar; + CPLX data[2]; + CPLX copydata[2]; + om[0] = cos(3); + om[1] = sin(3); + ombar[0] = om[0]; + ombar[1] = -om[1]; + data[0][0] = 47; + data[0][1] = 23; + data[1][0] = -12; + data[1][1] = -9; + memcpy(copydata, data, sizeof(copydata)); + citwiddle(data[0], data[1], om); + invcitwiddle(data[0], data[1], ombar); + double distance = 0; + for (uint64_t i = 0; i < 2; ++i) { + double d1 = fabs(data[i][0] - 2 * copydata[i][0]); + double d2 = fabs(data[i][1] - 2 * copydata[i][1]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, ctwiddle_then_invctwiddle) { + CPLX om; + CPLX ombar; + CPLX data[2]; + CPLX copydata[2]; + om[0] = cos(3); + om[1] = sin(3); + ombar[0] = om[0]; + ombar[1] = -om[1]; + data[0][0] 
= 47; + data[0][1] = 23; + data[1][0] = -12; + data[1][1] = -9; + memcpy(copydata, data, sizeof(copydata)); + ctwiddle(data[0], data[1], om); + invctwiddle(data[0], data[1], ombar); + double distance = 0; + for (uint64_t i = 0; i < 2; ++i) { + double d1 = fabs(data[i][0] - 2 * copydata[i][0]); + double d2 = fabs(data[i][1] - 2 * copydata[i][1]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, fft16_then_ifft16_ref) { + CPLX full_omegas[64]; + CPLX full_omegabars[64]; + for (uint64_t i = 0; i < 64; ++i) { + full_omegas[i][0] = cos(M_PI * i / 32.); + full_omegas[i][1] = sin(M_PI * i / 32.); + full_omegabars[i][0] = full_omegas[i][0]; + full_omegabars[i][1] = -full_omegas[i][1]; + } + CPLX omega[8]; + CPLX omegabar[8]; + cplx_set(omega[0], full_omegas[8]); // j + cplx_set(omega[1], full_omegas[4]); // k + cplx_set(omega[2], full_omegas[2]); // l + cplx_set(omega[3], full_omegas[10]); // lj + cplx_set(omega[4], full_omegas[1]); // n + cplx_set(omega[5], full_omegas[9]); // nj + cplx_set(omega[6], full_omegas[5]); // nk + cplx_set(omega[7], full_omegas[13]); // njk + cplx_set(omegabar[0], full_omegabars[1]); // n + cplx_set(omegabar[1], full_omegabars[9]); // nj + cplx_set(omegabar[2], full_omegabars[5]); // nk + cplx_set(omegabar[3], full_omegabars[13]); // njk + cplx_set(omegabar[4], full_omegabars[2]); // l + cplx_set(omegabar[5], full_omegabars[10]); // lj + cplx_set(omegabar[6], full_omegabars[4]); // k + cplx_set(omegabar[7], full_omegabars[8]); // j + CPLX data[16]; + CPLX copydata[16]; + for (uint64_t i = 0; i < 32; ++i) ((double*)data)[i] = rand() % 1000; + memcpy(copydata, data, sizeof(copydata)); + cplx_fft16_ref(data, omega); + cplx_ifft16_ref(data, omegabar); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - 16 * copydata[i][0]); + double d2 = fabs(data[i][0] - 16 * copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, halfcfft_ref_vs_naive) { + for (uint64_t nn : {4, 8, 16, 64, 256, 8192}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, m * sizeof(CPLX)); + memcpy(a2, a, m * sizeof(CPLX)); + + halfcfft_naive(nn, a1); + cplx_fft_naive(m, 0.25, a2); + + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + } +} + +#ifdef __x86_64__ +TEST(fft, halfcfft_fma_vs_ref) { + typedef void (*FFTF)(const CPLX_FFT_PRECOMP*, void* data); + for (FFTF fft : {cplx_fft_ref, cplx_fft_avx2_fma}) { + for (uint64_t nn : {8, 16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, nn / 
2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + cplx_fft_naive(m, 0.25, a2); + fft(tables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + } + } +} +#endif + +TEST(fft, halfcfft_then_ifft_ref) { + for (uint64_t nn : {4, 8, 16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX_IFFT_PRECOMP* itables = new_cplx_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + cplx_fft_ref(tables, a1); + cplx_ifft_ref(itables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a[i][0] - a1[i][0] / (nn / 2)); + double dim = fabs(a[i][1] - a1[i][1] / (nn / 2)); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + delete_cplx_fft_precomp(tables); + delete_cplx_ifft_precomp(itables); + } +} + +#ifdef __x86_64__ +TEST(fft, halfcfft_ifft_fma_vs_ref) { + for (IFFT_FUNCTION ifft : {cplx_ifft_ref, cplx_ifft_avx2_fma}) { + for (uint64_t nn : {8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_IFFT_PRECOMP* itables = new_cplx_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + cplx_ifft_naive(m, 0.25, a2); + ifft(itables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_ifft_precomp(itables); + } + } +} +#endif + +// test the reference and simple implementations of mul on all dimensions +TEST(fftvec, cplx_fftvec_mul_ref) { + for (uint64_t nn : {2, 4, 8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFTVEC_MUL_PRECOMP* precomp = new_cplx_fftvec_mul_precomp(m); + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r0 = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = r0[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = r0[i][1] = (rand() % p) - p / 
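As the round-trip tests above make explicit, the transforms are unnormalized: a single twiddle/inverse-twiddle pair returns 2x the input, the size-16 kernels return 16x, and a full forward plus inverse pass returns m = nn/2 times the input (hence the division in halfcfft_then_ifft_ref). A trivial rescaling sketch with a hypothetical name:

#include <stdint.h>

/* undo the m-fold scaling left by an unnormalized fft followed by ifft (buf holds 2*m doubles) */
static void reim_renormalize(uint32_t m, double* buf) {
  const double inv_m = 1.0 / (double)m;
  for (uint32_t i = 0; i < 2 * m; ++i) buf[i] *= inv_m;
}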
2; + } + cplx_fftvec_mul_simple(m, r0, a, b); + cplx_fftvec_mul_ref(precomp, r1, a, b); + for (uint32_t i = 0; i < m; i++) { + r2[i][0] = a[i][0] * b[i][0] - a[i][1] * b[i][1]; + r2[i][1] = a[i][0] * b[i][1] + a[i][1] * b[i][0]; + ASSERT_LE(fabs(r1[i][0] - r2[i][0]) + fabs(r1[i][1] - r2[i][1]), 1e-8); + ASSERT_LE(fabs(r0[i][0] - r2[i][0]) + fabs(r0[i][1] - r2[i][1]), 1e-8); + } + delete[] a; + delete[] b; + delete[] r0; + delete[] r1; + delete[] r2; + delete_cplx_fftvec_mul_precomp(precomp); + } +} + +// test the reference and simple implementations of addmul on all dimensions +TEST(fftvec, cplx_fftvec_addmul_ref) { + for (uint64_t nn : {2, 4, 8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFTVEC_ADDMUL_PRECOMP* precomp = new_cplx_fftvec_addmul_precomp(m); + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r0 = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = r0[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = r0[i][1] = (rand() % p) - p / 2; + } + cplx_fftvec_addmul_simple(m, r0, a, b); + cplx_fftvec_addmul_ref(precomp, r1, a, b); + for (uint32_t i = 0; i < m; i++) { + r2[i][0] += a[i][0] * b[i][0] - a[i][1] * b[i][1]; + r2[i][1] += a[i][0] * b[i][1] + a[i][1] * b[i][0]; + ASSERT_LE(fabs(r1[i][0] - r2[i][0]) + fabs(r1[i][1] - r2[i][1]), 1e-8); + ASSERT_LE(fabs(r0[i][0] - r2[i][0]) + fabs(r0[i][1] - r2[i][1]), 1e-8); + } + delete[] a; + delete[] b; + delete[] r0; + delete[] r1; + delete[] r2; + delete_cplx_fftvec_addmul_precomp(precomp); + } +} + +// comparative tests between mul ref vs. optimized (only relevant dimensions) +TEST(fftvec, cplx_fftvec_mul_ref_vs_optim) { + struct totest { + FFTVEC_MUL_FUNCTION f; + uint64_t min_m; + totest(FFTVEC_MUL_FUNCTION f, uint64_t min_m) : f(f), min_m(min_m) {} + }; + std::vector totestset; + totestset.emplace_back(cplx_fftvec_mul, 1); +#ifdef __x86_64__ + totestset.emplace_back(cplx_fftvec_mul_fma, 8); +#endif + for (uint64_t m : {1, 2, 4, 8, 16, 1024, 4096, 8192, 65536}) { + CPLX_FFTVEC_MUL_PRECOMP* precomp = new_cplx_fftvec_mul_precomp(m); + for (const totest& t : totestset) { + if (t.min_m > m) continue; + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = (rand() % p) - p / 2; + } + t.f(precomp, r1, a, b); + cplx_fftvec_mul_ref(precomp, r2, a, b); + for (uint32_t i = 0; i < m; i++) { + double dre = fabs(r1[i][0] - r2[i][0]); + double dim = fabs(r1[i][1] - r2[i][1]); + ASSERT_LE(dre, 1e-8); + ASSERT_LE(dim, 1e-8); + } + delete[] a; + delete[] b; + delete[] r1; + delete[] r2; + } + delete_cplx_fftvec_mul_precomp(precomp); + } +} + +// comparative tests between addmul ref vs. 
optimized (only relevant dimensions) +TEST(fftvec, cplx_fftvec_addmul_ref_vs_optim) { + struct totest { + FFTVEC_ADDMUL_FUNCTION f; + uint64_t min_m; + totest(FFTVEC_ADDMUL_FUNCTION f, uint64_t min_m) : f(f), min_m(min_m) {} + }; + std::vector totestset; + totestset.emplace_back(cplx_fftvec_addmul, 1); +#ifdef __x86_64__ + totestset.emplace_back(cplx_fftvec_addmul_fma, 8); +#endif + for (uint64_t m : {1, 2, 4, 8, 16, 1024, 4096, 8192, 65536}) { + CPLX_FFTVEC_ADDMUL_PRECOMP* precomp = new_cplx_fftvec_addmul_precomp(m); + for (const totest& t : totestset) { + if (t.min_m > m) continue; + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = (rand() % p) - p / 2; + } + t.f(precomp, r1, a, b); + cplx_fftvec_addmul_ref(precomp, r2, a, b); + for (uint32_t i = 0; i < m; i++) { + double dre = fabs(r1[i][0] - r2[i][0]); + double dim = fabs(r1[i][1] - r2[i][1]); + ASSERT_LE(dre, 1e-8); + ASSERT_LE(dim, 1e-8); + } + delete[] a; + delete[] b; + delete[] r1; + delete[] r2; + } + delete_cplx_fftvec_addmul_precomp(precomp); + } +} diff --git a/spqlios/lib/test/spqlios_q120_arithmetic_bench.cpp b/spqlios/lib/test/spqlios_q120_arithmetic_bench.cpp new file mode 100644 index 0000000..b344503 --- /dev/null +++ b/spqlios/lib/test/spqlios_q120_arithmetic_bench.cpp @@ -0,0 +1,136 @@ +#include + +#include + +#include "spqlios/q120/q120_arithmetic.h" + +#define ARGS Arg(128)->Arg(4096)->Arg(10000) + +template +void benchmark_baa(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_baa_precomp* precomp = q120_new_vec_mat1col_product_baa_precomp(); + + uint64_t* a = new uint64_t[ell * 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120a*)a, (q120a*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_baa_precomp(precomp); +} + +BENCHMARK(benchmark_baa)->Name("q120_vec_mat1col_product_baa_ref")->ARGS; +BENCHMARK(benchmark_baa)->Name("q120_vec_mat1col_product_baa_avx2")->ARGS; + +template +void benchmark_bbb(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbb_precomp* precomp = q120_new_vec_mat1col_product_bbb_precomp(); + + uint64_t* a = new uint64_t[ell * 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120b*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbb_precomp(precomp); +} + +BENCHMARK(benchmark_bbb)->Name("q120_vec_mat1col_product_bbb_ref")->ARGS; +BENCHMARK(benchmark_bbb)->Name("q120_vec_mat1col_product_bbb_avx2")->ARGS; + +template +void benchmark_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; 
i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_bbc)->Name("q120_vec_mat1col_product_bbc_ref")->ARGS; +BENCHMARK(benchmark_bbc)->Name("q120_vec_mat1col_product_bbc_avx2")->ARGS; + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +template +void benchmark_x2c2_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 8]; + uint64_t* b = new uint64_t[ell * 16]; + uint64_t* c = new uint64_t[16]; + for (uint64_t i = 0; i < 8 * ell; i++) { + a[i] = rand(); + } + for (uint64_t i = 0; i < 16 * ell; i++) { + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_x2c2_bbc)->Name("q120x2_vec_mat2col_product_bbc_avx2")->ARGS; + +template +void benchmark_x2c1_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 8]; + uint64_t* b = new uint64_t[ell * 8]; + uint64_t* c = new uint64_t[8]; + for (uint64_t i = 0; i < 8 * ell; i++) { + a[i] = rand(); + } + for (uint64_t i = 0; i < 8 * ell; i++) { + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_x2c1_bbc)->Name("q120x2_vec_mat1col_product_bbc_avx2")->ARGS; + +BENCHMARK_MAIN(); diff --git a/spqlios/lib/test/spqlios_q120_arithmetic_test.cpp b/spqlios/lib/test/spqlios_q120_arithmetic_test.cpp new file mode 100644 index 0000000..ca261ff --- /dev/null +++ b/spqlios/lib/test/spqlios_q120_arithmetic_test.cpp @@ -0,0 +1,437 @@ +#include + +#include +#include + +#include "spqlios/q120/q120_arithmetic.h" +#include "test/testlib/negacyclic_polynomial.h" +#include "test/testlib/ntt120_layouts.h" +#include "testlib/mod_q120.h" + +typedef typeof(q120_vec_mat1col_product_baa_ref) vec_mat1col_product_baa_f; + +void test_vec_mat1col_product_baa(vec_mat1col_product_baa_f vec_mat1col_product_baa) { + q120_mat1col_product_baa_precomp* precomp = q120_new_vec_mat1col_product_baa_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120a(pa + 4 * i); + uniform_q120a(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120a(pa + 4 * i) * mod_q120::from_q120a(pb + 4 * i); + } + // compute the function + vec_mat1col_product_baa(precomp, ell, (q120b*)pr, (q120a*)pa, (q120a*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + 
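For readers less familiar with the q120 format: each value is carried as four residues modulo four roughly 30-bit primes, and the expected-result loop above is simply a modular dot product accumulated limb by limb. The following self-contained sketch mirrors that loop with clearly labeled placeholder primes; the real moduli and the q120a/q120b/q120c limb encodings are defined in the spqlios q120 headers and are not reproduced here.

#include <array>
#include <cstdint>

// Placeholder moduli standing in for the library's four ~30-bit primes
// (the actual values live in the q120 headers, not here).
static constexpr std::array<uint64_t, 4> kQ = {1000000007ULL, 998244353ULL, 754974721ULL, 167772161ULL};

struct rns4 {
  std::array<uint64_t, 4> r{};  // one residue per modulus

  static rns4 from_u64(uint64_t x) {
    rns4 out;
    for (int k = 0; k < 4; ++k) out.r[k] = x % kQ[k];
    return out;
  }

  // r += a * b, limb by limb; each residue is < 2^30, so a.r[k] * b.r[k] fits in 64 bits
  void add_mul(const rns4& a, const rns4& b) {
    for (int k = 0; k < 4; ++k) r[k] = (r[k] + a.r[k] * b.r[k] % kQ[k]) % kQ[k];
  }
};

// Usage, mirroring the expected-result loop of the test above:
//   rns4 acc;
//   for (uint64_t i = 0; i < ell; ++i)
//     acc.add_mul(rns4::from_u64(x[i]), rns4::from_u64(y[i]));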
// check for equality + ASSERT_EQ(comp_r, expect_r) << ell; + } + q120_delete_vec_mat1col_product_baa_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_baa_ref) { + test_vec_mat1col_product_baa(q120_vec_mat1col_product_baa_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_baa_avx2) { + test_vec_mat1col_product_baa(q120_vec_mat1col_product_baa_avx2); +} +#endif + +typedef typeof(q120_vec_mat1col_product_bbb_ref) vec_mat1col_product_bbb_f; + +void test_vec_mat1col_product_bbb(vec_mat1col_product_bbb_f vec_mat1col_product_bbb) { + q120_mat1col_product_bbb_precomp* precomp = q120_new_vec_mat1col_product_bbb_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120b(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120b(pa + 4 * i) * mod_q120::from_q120b(pb + 4 * i); + } + // compute the function + vec_mat1col_product_bbb(precomp, ell, (q120b*)pr, (q120b*)pa, (q120b*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + // check for equality + ASSERT_EQ(comp_r, expect_r); + } + q120_delete_vec_mat1col_product_bbb_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_bbb_ref) { + test_vec_mat1col_product_bbb(q120_vec_mat1col_product_bbb_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_bbb_avx2) { + test_vec_mat1col_product_bbb(q120_vec_mat1col_product_bbb_avx2); +} +#endif + +typedef typeof(q120_vec_mat1col_product_bbc_ref) vec_mat1col_product_bbc_f; + +void test_vec_mat1col_product_bbc(vec_mat1col_product_bbc_f vec_mat1col_product_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120b(pa + 4 * i) * mod_q120::from_q120c(pb + 4 * i); + } + // compute the function + vec_mat1col_product_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + // check for equality + ASSERT_EQ(comp_r, expect_r); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_bbc_ref) { + test_vec_mat1col_product_bbc(q120_vec_mat1col_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_bbc_avx2) { + test_vec_mat1col_product_bbc(q120_vec_mat1col_product_bbc_avx2); +} +#endif + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +typedef typeof(q120x2_vec_mat2cols_product_bbc_avx2) q120x2_prod_bbc_f; + +void 
test_q120x2_vec_mat2cols_product_bbc(q120x2_prod_bbc_f q120x2_prod_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 8); + std::vector b(ell * 16); + std::vector res(16); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120b(pa + 4 * i); + } + for (uint64_t i = 0; i < 4 * ell; ++i) { + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r[4]; + for (uint64_t i = 0; i < ell; ++i) { + mod_q120 va = mod_q120::from_q120b(pa + 8 * i); + mod_q120 vb = mod_q120::from_q120b(pa + 8 * i + 4); + mod_q120 m1a = mod_q120::from_q120c(pb + 16 * i); + mod_q120 m1b = mod_q120::from_q120c(pb + 16 * i + 4); + mod_q120 m2a = mod_q120::from_q120c(pb + 16 * i + 8); + mod_q120 m2b = mod_q120::from_q120c(pb + 16 * i + 12); + expect_r[0] += va * m1a; + expect_r[1] += vb * m1b; + expect_r[2] += va * m2a; + expect_r[3] += vb * m2b; + } + // compute the function + q120x2_prod_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + // check for equality + ASSERT_EQ(mod_q120::from_q120b(pr), expect_r[0]); + ASSERT_EQ(mod_q120::from_q120b(pr + 4), expect_r[1]); + ASSERT_EQ(mod_q120::from_q120b(pr + 8), expect_r[2]); + ASSERT_EQ(mod_q120::from_q120b(pr + 12), expect_r[3]); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120x2_vec_mat2cols_product_bbc_ref) { + test_q120x2_vec_mat2cols_product_bbc(q120x2_vec_mat2cols_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120x2_vec_mat2cols_product_bbc_avx2) { + test_q120x2_vec_mat2cols_product_bbc(q120x2_vec_mat2cols_product_bbc_avx2); +} +#endif + +typedef typeof(q120x2_vec_mat1col_product_bbc_avx2) q120x2_c1_prod_bbc_f; + +void test_q120x2_vec_mat1col_product_bbc(q120x2_c1_prod_bbc_f q120x2_c1_prod_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 8); + std::vector b(ell * 8); + std::vector res(8); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120b(pa + 4 * i); + } + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r[2]; + for (uint64_t i = 0; i < ell; ++i) { + mod_q120 va = mod_q120::from_q120b(pa + 8 * i); + mod_q120 vb = mod_q120::from_q120b(pa + 8 * i + 4); + mod_q120 m1a = mod_q120::from_q120c(pb + 8 * i); + mod_q120 m1b = mod_q120::from_q120c(pb + 8 * i + 4); + expect_r[0] += va * m1a; + expect_r[1] += vb * m1b; + } + // compute the function + q120x2_c1_prod_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + // check for equality + ASSERT_EQ(mod_q120::from_q120b(pr), expect_r[0]); + ASSERT_EQ(mod_q120::from_q120b(pr + 4), expect_r[1]); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120x2_vec_mat1col_product_bbc_ref) { + test_q120x2_vec_mat1col_product_bbc(q120x2_vec_mat1col_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120x2_vec_mat1col_product_bbc_avx2) { + test_q120x2_vec_mat1col_product_bbc(q120x2_vec_mat1col_product_bbc_avx2); +} +#endif + +typedef typeof(q120x2_extract_1blk_from_q120b_ref) q120x2_extract_f; +void 
test_q120x2_extract_1blk(q120x2_extract_f q120x2_extract) { + for (uint64_t n : {2, 4, 64}) { + ntt120_vec_znx_dft_layout v(n, 1); + std::vector r(8); + std::vector expect(8); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8; ++i) { + expect[i] = uniform_u64(); + } + memcpy(v.get_blk(0, blk), expect.data(), 8 * sizeof(uint64_t)); + q120x2_extract_1blk_from_q120b_ref(n, blk, (q120x2b*)r.data(), (q120b*)v.data); + ASSERT_EQ(r, expect); + } + } +} + +TEST(q120_arithmetic, q120x2_extract_1blk_from_q120b_ref) { + test_q120x2_extract_1blk(q120x2_extract_1blk_from_q120b_ref); +} + +typedef typeof(q120x2_extract_1blk_from_contiguous_q120b_ref) q120x2_extract_vec_f; +void test_q120x2_extract_1blk_vec(q120x2_extract_vec_f q120x2_extract) { + for (uint64_t n : {2, 4, 32}) { + for (uint64_t size : {1, 2, 7}) { + ntt120_vec_znx_dft_layout v(n, size); + std::vector r(8 * size); + std::vector expect(8 * size); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8 * size; ++i) { + expect[i] = uniform_u64(); + } + for (uint64_t i = 0; i < size; ++i) { + memcpy(v.get_blk(i, blk), expect.data() + 8 * i, 8 * sizeof(uint64_t)); + } + q120x2_extract(n, size, blk, (q120x2b*)r.data(), (q120b*)v.data); + ASSERT_EQ(r, expect); + } + } + } +} + +TEST(q120_arithmetic, q120x2_extract_1blk_from_contiguous_q120b_ref) { + test_q120x2_extract_1blk_vec(q120x2_extract_1blk_from_contiguous_q120b_ref); +} + +typedef typeof(q120x2b_save_1blk_to_q120b_ref) q120x2_save_f; +void test_q120x2_save_1blk(q120x2_save_f q120x2_save) { + for (uint64_t n : {2, 4, 64}) { + ntt120_vec_znx_dft_layout v(n, 1); + std::vector r(8); + std::vector expect(8); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8; ++i) { + expect[i] = uniform_u64(); + } + q120x2_save(n, blk, (q120b*)v.data, (q120x2b*)expect.data()); + memcpy(r.data(), v.get_blk(0, blk), 8 * sizeof(uint64_t)); + ASSERT_EQ(r, expect); + } + } +} + +TEST(q120_arithmetic, q120x2b_save_1blk_to_q120b_ref) { test_q120x2_save_1blk(q120x2b_save_1blk_to_q120b_ref); } + +TEST(q120_arithmetic, q120_add_bbb_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector a(n * 4); + std::vector b(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120b(pb + 4 * i); + } + + // compute the function + q120_add_bbb_simple(n, (q120b*)pr, (q120b*)pa, (q120b*)pb); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120b(pa + 4 * i); + mod_q120 be = mod_q120::from_q120b(pb + 4 * i); + mod_q120 re = mod_q120::from_q120b(pr + 4 * i); + + ASSERT_EQ(ae + be, re); + } + } +} + +TEST(q120_arithmetic, q120_add_ccc_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector a(n * 4); + std::vector b(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120c(pa + 4 * i); + uniform_q120c(pb + 4 * i); + } + + // compute the function + q120_add_ccc_simple(n, (q120c*)pr, (q120c*)pa, (q120c*)pb); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120c(pa + 4 * i); + mod_q120 be = mod_q120::from_q120c(pb + 4 * i); + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + ASSERT_EQ(ae + be, re); + } + } +} + +TEST(q120_arithmetic, q120_c_from_b_simple) { + for (const uint64_t n : {2, 4, 1024}) { + 
std::vector a(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(pa + 4 * i); + } + + // compute the function + q120_c_from_b_simple(n, (q120c*)pr, (q120b*)pa); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120b(pa + 4 * i); + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + ASSERT_EQ(ae, re); + } + } +} + +TEST(q120_arithmetic, q120_b_from_znx64_simple) { + for (const uint64_t n : {2, 4, 1024}) { + znx_i64 x = znx_i64::random_log2bound(n, 62); + std::vector r(n * 4); + uint64_t* pr = r.data(); + + q120_b_from_znx64_simple(n, (q120b*)pr, x.data()); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 re = mod_q120::from_q120b(pr + 4 * i); + + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod(x.get_coeff(i), mod_q120::Qi[k]), re.a[k]); + } + } + } +} + +TEST(q120_arithmetic, q120_c_from_znx64_simple) { + for (const uint64_t n : {2, 4, 1024}) { + znx_i64 x = znx_i64::random(n); + std::vector r(n * 4); + uint64_t* pr = r.data(); + + q120_c_from_znx64_simple(n, (q120c*)pr, x.data()); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod(x.get_coeff(i), mod_q120::Qi[k]), re.a[k]); + } + } + } +} + +TEST(q120_arithmetic, q120_b_to_znx128_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector x(n * 4); + uint64_t* px = x.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(px + 4 * i); + } + + znx_i128 r(n); + q120_b_to_znx128_simple(n, r.data(), (q120b*)px); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 xe = mod_q120::from_q120b(px + 4 * i); + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod((int64_t)(r.get_coeff(i) % mod_q120::Qi[k]), mod_q120::Qi[k]), xe.a[k]); + } + } + } +} diff --git a/spqlios/lib/test/spqlios_q120_ntt_bench.cpp b/spqlios/lib/test/spqlios_q120_ntt_bench.cpp new file mode 100644 index 0000000..50e1f15 --- /dev/null +++ b/spqlios/lib/test/spqlios_q120_ntt_bench.cpp @@ -0,0 +1,44 @@ +#include + +#include + +#include "spqlios/q120/q120_ntt.h" + +#define ARGS Arg(1 << 10)->Arg(1 << 11)->Arg(1 << 12)->Arg(1 << 13)->Arg(1 << 14)->Arg(1 << 15)->Arg(1 << 16) + +template +void benchmark_ntt(benchmark::State& state) { + const uint64_t n = state.range(0); + q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n); + + uint64_t* px = new uint64_t[n * 4]; + for (uint64_t i = 0; i < 4 * n; i++) { + px[i] = (rand() << 31) + rand(); + } + for (auto _ : state) { + f(precomp, (q120b*)px); + } + delete[] px; + q120_del_ntt_bb_precomp(precomp); +} + +template +void benchmark_intt(benchmark::State& state) { + const uint64_t n = state.range(0); + q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n); + + uint64_t* px = new uint64_t[n * 4]; + for (uint64_t i = 0; i < 4 * n; i++) { + px[i] = (rand() << 31) + rand(); + } + for (auto _ : state) { + f(precomp, (q120b*)px); + } + delete[] px; + q120_del_intt_bb_precomp(precomp); +} + +BENCHMARK(benchmark_ntt)->Name("q120_ntt_bb_avx2")->ARGS; +BENCHMARK(benchmark_intt)->Name("q120_intt_bb_avx2")->ARGS; + +BENCHMARK_MAIN(); diff --git a/spqlios/lib/test/spqlios_q120_ntt_test.cpp b/spqlios/lib/test/spqlios_q120_ntt_test.cpp new file mode 100644 index 0000000..61c9dd6 --- /dev/null +++ b/spqlios/lib/test/spqlios_q120_ntt_test.cpp @@ -0,0 +1,174 @@ +#include + +#include +#include +#include +#include +#include + +#include "spqlios/q120/q120_common.h" 
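The conversion tests above compare against centermod(x, Qi[k]), i.e. the centered representative of x modulo each of the four primes. A minimal version of that reduction is sketched below; the testlib helper may choose a different convention at the boundary values, so treat the exact interval here as an assumption.

#include <cstdint>

// Sketch of a centered reduction: map x to the representative of x mod q
// closest to zero (the testlib centermod fixes its own boundary convention).
static int64_t centermod_sketch(int64_t x, int64_t q) {
  int64_t r = x % q;        // in (-q, q), with the sign of x (C++ semantics)
  if (r > q / 2) r -= q;    // fold the upper half down
  if (r < -q / 2) r += q;   // fold the lower half up
  return r;
}

// Examples: centermod_sketch(7, 5) == 2, centermod_sketch(8, 5) == -2,
//           centermod_sketch(-8, 5) == 2.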
+#include "spqlios/q120/q120_ntt.h" +#include "testlib/mod_q120.h" + +std::vector q120_ntt(const std::vector& x) { + const uint64_t n = x.size(); + + mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + mod_q120 omega = pow(omega_2pow17, (1 << 16) / n); + + std::vector res(n); + for (uint64_t i = 0; i < n; ++i) { + res[i] = x[i]; + } + + for (uint64_t i = 0; i < n; ++i) { + res[i] = res[i] * pow(omega, i); + } + + for (uint64_t nn = n; nn > 1; nn /= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t j = 0; j < n; j += nn) { + for (uint64_t k = 0; k < halfnn; ++k) { + mod_q120 a = res[j + k]; + mod_q120 b = res[j + halfnn + k]; + + res[j + k] = a + b; + res[j + halfnn + k] = (a - b) * pow(omega, k * m); + } + } + } + + return res; +} + +std::vector q120_intt(const std::vector& x) { + const uint64_t n = x.size(); + + mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + mod_q120 omega = pow(omega_2pow17, (1 << 16) / n); + + std::vector res(n); + for (uint64_t i = 0; i < n; ++i) { + res[i] = x[i]; + } + + for (uint64_t nn = 2; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t j = 0; j < n; j += nn) { + for (uint64_t k = 0; k < halfnn; ++k) { + mod_q120 a = res[j + k]; + mod_q120 b = res[j + halfnn + k]; + + mod_q120 bo = b * pow(omega, -k * m); + res[j + k] = a + bo; + res[j + halfnn + k] = a - bo; + } + } + } + + mod_q120 n_q120{(int64_t)n, (int64_t)n, (int64_t)n, (int64_t)n}; + mod_q120 n_inv_q120 = pow(n_q120, -1); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 po = pow(omega, -i) * n_inv_q120; + res[i] = res[i] * po; + } + + return res; +} + +class ntt : public testing::TestWithParam {}; + +#ifdef __x86_64__ + +TEST_P(ntt, q120_ntt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + std::vector y_exp = q120_ntt(x_modq); + + q120_ntt_bb_avx2(precomp, (q120b*)px); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, y_exp[i]) << i; + } + + q120_del_ntt_bb_precomp(precomp); +} + +TEST_P(ntt, q120_intt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + q120_intt_bb_avx2(precomp, (q120b*)px); + + std::vector y_exp = q120_intt(x_modq); + for (uint64_t i = 0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, y_exp[i]) << i; + } + + q120_del_intt_bb_precomp(precomp); +} + +TEST_P(ntt, q120_ntt_intt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp_ntt = q120_new_ntt_bb_precomp(n); + q120_ntt_precomp* precomp_intt = q120_new_intt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + q120_ntt_bb_avx2(precomp_ntt, (q120b*)px); + q120_intt_bb_avx2(precomp_intt, (q120b*)px); + + for (uint64_t i = 
0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, x_modq[i]) << i; + } + + q120_del_intt_bb_precomp(precomp_intt); + q120_del_ntt_bb_precomp(precomp_ntt); +} + +INSTANTIATE_TEST_SUITE_P(q120, ntt, + testing::Values(1, 2, 4, 16, 256, UINT64_C(1) << 10, UINT64_C(1) << 11, UINT64_C(1) << 12, + UINT64_C(1) << 13, UINT64_C(1) << 14, UINT64_C(1) << 15, UINT64_C(1) << 16), + testing::PrintToStringParamName()); + +#endif diff --git a/spqlios/lib/test/spqlios_reim4_arithmetic_bench.cpp b/spqlios/lib/test/spqlios_reim4_arithmetic_bench.cpp new file mode 100644 index 0000000..2acd7df --- /dev/null +++ b/spqlios/lib/test/spqlios_reim4_arithmetic_bench.cpp @@ -0,0 +1,52 @@ +#include + +#include "spqlios/reim4/reim4_arithmetic.h" + +void init_random_values(uint64_t n, double* v) { + for (uint64_t i = 0; i < n; ++i) + v[i] = (double(rand() % (UINT64_C(1) << 14)) - (UINT64_C(1) << 13)) / (UINT64_C(1) << 12); +} + +// Run the benchmark +BENCHMARK_MAIN(); + +#undef ARGS +#define ARGS Args({47, 16384})->Args({93, 32768}) + +/* + * reim4_vec_mat1col_product + * reim4_vec_mat2col_product + * reim4_vec_mat3col_product + * reim4_vec_mat4col_product + */ + +template +void benchmark_reim4_vec_matXcols_product(benchmark::State& state) { + const uint64_t nrows = state.range(0); + + double* u = new double[nrows * 8]; + init_random_values(8 * nrows, u); + double* v = new double[nrows * X * 8]; + init_random_values(X * 8 * nrows, v); + double* dst = new double[X * 8]; + + for (auto _ : state) { + fnc(nrows, dst, u, v); + } + + delete[] dst; + delete[] v; + delete[] u; +} + +#undef ARGS +#define ARGS Arg(128)->Arg(1024)->Arg(4096) + +#ifdef __x86_64__ +BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_avx2>)->ARGS; +// TODO: please remove when fixed: +BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_avx2>)->ARGS; +#endif +BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_ref>)->ARGS; +BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_ref>)->ARGS; diff --git a/spqlios/lib/test/spqlios_reim4_arithmetic_test.cpp b/spqlios/lib/test/spqlios_reim4_arithmetic_test.cpp new file mode 100644 index 0000000..671bb7e --- /dev/null +++ b/spqlios/lib/test/spqlios_reim4_arithmetic_test.cpp @@ -0,0 +1,253 @@ +#include + +#include +#include + +#include "../spqlios/reim4/reim4_arithmetic.h" +#include "test/testlib/reim4_elem.h" + +/// Actual tests + +typedef typeof(reim4_extract_1blk_from_reim_ref) reim4_extract_1blk_from_reim_f; +void test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_f reim4_extract_1blk_from_reim) { + static const uint64_t numtrials = 100; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + double* v = (double*)malloc(2 * m * sizeof(double)); + double* w = (double*)malloc(8 * sizeof(double)); + reim_view vv(m, v); + for (uint64_t i = 0; i < numtrials; ++i) { + reim4_elem el = gaussian_reim4(); + uint64_t blk = rand() % (m / 4); + vv.set_blk(blk, el); + reim4_extract_1blk_from_reim(m, blk, w, v); + reim4_elem actual(w); + ASSERT_EQ(el, actual); + } + free(v); + free(w); + } +} + +TEST(reim4_arithmetic, reim4_extract_1blk_from_reim_ref) { + test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_extract_1blk_from_reim_avx) { + test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_avx); +} +#endif + +typedef typeof(reim4_save_1blk_to_reim_ref) reim4_save_1blk_to_reim_f; +void 
test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_f reim4_save_1blk_to_reim) { + static const uint64_t numtrials = 100; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + double* v = (double*)malloc(2 * m * sizeof(double)); + double* w = (double*)malloc(8 * sizeof(double)); + reim_view vv(m, v); + for (uint64_t i = 0; i < numtrials; ++i) { + reim4_elem el = gaussian_reim4(); + el.save_as(w); + uint64_t blk = rand() % (m / 4); + reim4_save_1blk_to_reim_ref(m, blk, v, w); + reim4_elem actual = vv.get_blk(blk); + ASSERT_EQ(el, actual); + } + free(v); + free(w); + } +} + +TEST(reim4_arithmetic, reim4_save_1blk_to_reim_ref) { test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_ref); } +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_save_1blk_to_reim_avx) { test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_avx); } +#endif + +typedef typeof(reim4_extract_1blk_from_contiguous_reim_ref) reim4_extract_1blk_from_contiguous_reim_f; +void test_reim4_extract_1blk_from_contiguous_reim( + reim4_extract_1blk_from_contiguous_reim_f reim4_extract_1blk_from_contiguous_reim) { + static const uint64_t numtrials = 20; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + for (uint64_t nrows : {1, 2, 5, 128}) { + double* v = (double*)malloc(2 * m * nrows * sizeof(double)); + double* w = (double*)malloc(8 * nrows * sizeof(double)); + reim_vector_view vv(m, nrows, v); + reim4_array_view ww(nrows, w); + for (uint64_t i = 0; i < numtrials; ++i) { + uint64_t blk = rand() % (m / 4); + for (uint64_t j = 0; j < nrows; ++j) { + reim4_elem el = gaussian_reim4(); + vv.row(j).set_blk(blk, el); + } + reim4_extract_1blk_from_contiguous_reim_ref(m, nrows, blk, w, v); + for (uint64_t j = 0; j < nrows; ++j) { + reim4_elem el = vv.row(j).get_blk(blk); + reim4_elem actual = ww.get(j); + ASSERT_EQ(el, actual); + } + } + free(v); + free(w); + } + } +} + +TEST(reim4_arithmetic, reim4_extract_1blk_from_contiguous_reim_ref) { + test_reim4_extract_1blk_from_contiguous_reim(reim4_extract_1blk_from_contiguous_reim_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_extract_1blk_from_contiguous_reim_avx) { + test_reim4_extract_1blk_from_contiguous_reim(reim4_extract_1blk_from_contiguous_reim_avx); +} +#endif + +// test of basic arithmetic functions + +TEST(reim4_arithmetic, add) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem expect = x + y; + reim4_elem actual; + reim4_add(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect); +} + +TEST(reim4_arithmetic, mul) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem expect = x * y; + reim4_elem actual; + reim4_mul(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect); +} + +TEST(reim4_arithmetic, add_mul) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem z = gaussian_reim4(); + reim4_elem expect = z; + reim4_elem actual = z; + expect += x * y; + reim4_add_mul(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect) << infty_dist(expect, actual); +} + +// test of dot products + +typedef typeof(reim4_vec_mat1col_product_ref) reim4_vec_mat1col_product_f; +void test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_f product) { + for (uint64_t ell : {1, 2, 5, 13, 69, 129}) { + std::vector actual(8); + std::vector a(ell * 8); + std::vector b(ell * 8); + reim4_array_view va(ell, a.data()); + reim4_array_view vb(ell, b.data()); + reim4_array_view vactual(1, actual.data()); + // initialize random values + for (uint64_t i = 0; i < ell; ++i) { + va.set(i, 
gaussian_reim4()); + vb.set(i, gaussian_reim4()); + } + // compute the mat1col product + reim4_elem expect; + for (uint64_t i = 0; i < ell; ++i) { + expect += va.get(i) * vb.get(i); + } + // compute the actual product + product(ell, actual.data(), a.data(), b.data()); + // compare + ASSERT_LE(infty_dist(vactual.get(0), expect), 1e-10); + } +} + +TEST(reim4_arithmetic, reim4_vec_mat1col_product_ref) { test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_ref); } +#ifdef __x86_64__ +TEST(reim4_arena, reim4_vec_mat1col_product_avx2) { test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_avx2); } +#endif + +typedef typeof(reim4_vec_mat2cols_product_ref) reim4_vec_mat2col_product_f; +void test_reim4_vec_mat2cols_product(reim4_vec_mat2col_product_f product) { + for (uint64_t ell : {1, 2, 5, 13, 69, 129}) { + std::vector actual(16); + std::vector a(ell * 8); + std::vector b(ell * 16); + reim4_array_view va(ell, a.data()); + reim4_matrix_view vb(ell, 2, b.data()); + reim4_array_view vactual(2, actual.data()); + // initialize random values + for (uint64_t i = 0; i < ell; ++i) { + va.set(i, gaussian_reim4()); + vb.set(i, 0, gaussian_reim4()); + vb.set(i, 1, gaussian_reim4()); + } + // compute the mat1col product + reim4_elem expect[2]; + for (uint64_t i = 0; i < ell; ++i) { + expect[0] += va.get(i) * vb.get(i, 0); + expect[1] += va.get(i) * vb.get(i, 1); + } + // compute the actual product + product(ell, actual.data(), a.data(), b.data()); + // compare + ASSERT_LE(infty_dist(vactual.get(0), expect[0]), 1e-10); + ASSERT_LE(infty_dist(vactual.get(1), expect[1]), 1e-10); + } +} + +TEST(reim4_arithmetic, reim4_vec_mat2cols_product_ref) { + test_reim4_vec_mat2cols_product(reim4_vec_mat2cols_product_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_vec_mat2cols_product_avx2) { + test_reim4_vec_mat2cols_product(reim4_vec_mat2cols_product_avx2); +} +#endif + +// for now, we do not need avx implementations, +// so we will keep a single test function +TEST(reim4_arithmetic, reim4_vec_convolution_ref) { + for (uint64_t sizea : {1, 2, 3, 5, 8}) { + for (uint64_t sizeb : {1, 3, 6, 9, 13}) { + std::vector a(8 * sizea); + std::vector b(8 * sizeb); + std::vector expect(8 * (sizea + sizeb - 1)); + std::vector actual(8 * (sizea + sizeb - 1)); + reim4_array_view va(sizea, a.data()); + reim4_array_view vb(sizeb, b.data()); + std::vector vexpect(sizea + sizeb + 3); + reim4_array_view vactual(sizea + sizeb - 1, actual.data()); + for (uint64_t i = 0; i < sizea; ++i) { + va.set(i, gaussian_reim4()); + } + for (uint64_t j = 0; j < sizeb; ++j) { + vb.set(j, gaussian_reim4()); + } + // manual convolution + for (uint64_t i = 0; i < sizea; ++i) { + for (uint64_t j = 0; j < sizeb; ++j) { + vexpect[i + j] += va.get(i) * vb.get(j); + } + } + // partial convolution single coeff + for (uint64_t k = 0; k < sizea + sizeb + 3; ++k) { + double dest[8] = {0}; + reim4_convolution_1coeff_ref(k, dest, a.data(), sizea, b.data(), sizeb); + ASSERT_LE(infty_dist(reim4_elem(dest), vexpect[k]), 1e-10); + } + // partial convolution dual coeff + for (uint64_t k = 0; k < sizea + sizeb + 2; ++k) { + double dest[16] = {0}; + reim4_convolution_2coeff_ref(k, dest, a.data(), sizea, b.data(), sizeb); + ASSERT_LE(infty_dist(reim4_elem(dest), vexpect[k]), 1e-10); + ASSERT_LE(infty_dist(reim4_elem(dest + 8), vexpect[k + 1]), 1e-10); + } + // actual convolution + reim4_convolution_ref(actual.data(), sizea + sizeb - 1, 0, a.data(), sizea, b.data(), sizeb); + for (uint64_t k = 0; k < sizea + sizeb - 1; ++k) { + 
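As a reminder of what the reim4 kernels operate on: a reim4 element packs four complex numbers into eight doubles, assumed here to be the four real parts followed by the four matching imaginary parts (consistent with blocks taken from the split re/im layout). The sketch below shows the element-wise complex product in that packing; it is an illustration under that layout assumption, not the library's reim4_mul.

// Sketch, assuming a reim4 element stores re[0..3] followed by im[0..3].
// Element-wise complex product of two such blocks:
static void reim4_mul_sketch(double* r, const double* a, const double* b) {
  for (int k = 0; k < 4; ++k) {
    const double are = a[k], aim = a[4 + k];
    const double bre = b[k], bim = b[4 + k];
    r[k]     = are * bre - aim * bim;  // real part
    r[4 + k] = are * bim + aim * bre;  // imaginary part
  }
}

// The mat1col/mat2cols products and the convolution tested above are then
// accumulations of such block products over the input rows.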
ASSERT_LE(infty_dist(vactual.get(k), vexpect[k]), 1e-10) << k; + } + } + } +} + +EXPORT void reim4_convolution_ref(double* dest, uint64_t dest_size, uint64_t dest_offset, const double* a, + uint64_t sizea, const double* b, uint64_t sizeb); diff --git a/spqlios/lib/test/spqlios_reim_conversions_test.cpp b/spqlios/lib/test/spqlios_reim_conversions_test.cpp new file mode 100644 index 0000000..d266f89 --- /dev/null +++ b/spqlios/lib/test/spqlios_reim_conversions_test.cpp @@ -0,0 +1,115 @@ +#include +#include + +#include "testlib/test_commons.h" + +TEST(reim_conversions, reim_to_tnx) { + for (uint32_t m : {1, 2, 64, 128, 512}) { + for (double divisor : {1, 2, int(m)}) { + for (uint32_t log2overhead : {1, 2, 10, 18, 35, 42}) { + double maxdiff = pow(2., log2overhead - 50); + std::vector data(2 * m); + std::vector dout(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + data[i] = (uniform_f64_01() - 0.5) * pow(2., log2overhead + 1) * divisor; + } + REIM_TO_TNX_PRECOMP* p = new_reim_to_tnx_precomp(m, divisor, 18); + reim_to_tnx_ref(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(fabs(dout[i]), 0.5); + double diff = dout[i] - data[i] / divisor; + double fracdiff = diff - rint(diff); + ASSERT_LE(fabs(fracdiff), maxdiff); + } + delete_reim_to_tnx_precomp(p); + } + } + } +} + +#ifdef __x86_64__ +TEST(reim_conversions, reim_to_tnx_ref_vs_avx) { + for (uint32_t m : {8, 16, 64, 128, 512}) { + for (double divisor : {1, 2, int(m)}) { + for (uint32_t log2overhead : {1, 2, 10, 18, 35, 42}) { + // double maxdiff = pow(2., log2overhead - 50); + std::vector data(2 * m); + std::vector dout1(2 * m); + std::vector dout2(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + data[i] = (uniform_f64_01() - 0.5) * pow(2., log2overhead + 1) * divisor; + } + REIM_TO_TNX_PRECOMP* p = new_reim_to_tnx_precomp(m, divisor, 18); + reim_to_tnx_ref(p, dout1.data(), data.data()); + reim_to_tnx_avx(p, dout2.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(fabs(dout1[i] - dout2[i]), 0.); + } + delete_reim_to_tnx_precomp(p); + } + } + } +} +#endif + +typedef typeof(reim_from_znx64_ref) reim_from_znx64_f; + +void test_reim_from_znx64(reim_from_znx64_f reim_from_znx64, uint64_t maxbnd) { + for (uint32_t m : {4, 8, 16, 64, 16384}) { + REIM_FROM_ZNX64_PRECOMP* p = new_reim_from_znx64_precomp(m, maxbnd); + std::vector data(2 * m); + std::vector dout(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + int64_t magnitude = int64_t(uniform_u64() % (maxbnd + 1)); + data[i] = uniform_i64() >> (63 - magnitude); + REQUIRE_DRAMATICALLY(abs(data[i]) <= (INT64_C(1) << magnitude), "pb"); + } + reim_from_znx64(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_EQ(dout[i], double(data[i])) << dout[i] << " " << data[i]; + } + delete_reim_from_znx64_precomp(p); + } +} + +TEST(reim_conversions, reim_from_znx64) { + for (uint64_t maxbnd : {50}) { + test_reim_from_znx64(reim_from_znx64, maxbnd); + } +} +TEST(reim_conversions, reim_from_znx64_ref) { test_reim_from_znx64(reim_from_znx64_ref, 50); } +#ifdef __x86_64__ +TEST(reim_conversions, reim_from_znx64_avx2_bnd50_fma) { test_reim_from_znx64(reim_from_znx64_bnd50_fma, 50); } +#endif + +typedef typeof(reim_to_znx64_ref) reim_to_znx64_f; + +void test_reim_to_znx64(reim_to_znx64_f reim_to_znx64_fcn, int64_t maxbnd) { + for (uint32_t m : {4, 8, 16, 64, 16384}) { + for (double divisor : {1, 2, int(m)}) { + REIM_TO_ZNX64_PRECOMP* p = new_reim_to_znx64_precomp(m, divisor, maxbnd); + std::vector data(2 * m); + std::vector dout(2 
* m); + for (uint64_t i = 0; i < 2 * m; ++i) { + int64_t magnitude = int64_t(uniform_u64() % (maxbnd + 11)) - 10; + data[i] = (uniform_f64_01() - 0.5) * pow(2., magnitude + 1) * divisor; + } + reim_to_znx64_fcn(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(dout[i] - data[i] / divisor, 0.5) << dout[i] << " " << data[i]; + } + delete_reim_to_znx64_precomp(p); + } + } +} + +TEST(reim_conversions, reim_to_znx64) { + for (uint64_t maxbnd : {63, 50}) { + test_reim_to_znx64(reim_to_znx64, maxbnd); + } +} +TEST(reim_conversions, reim_to_znx64_ref) { test_reim_to_znx64(reim_to_znx64_ref, 63); } +#ifdef __x86_64__ +TEST(reim_conversions, reim_to_znx64_avx2_bnd63_fma) { test_reim_to_znx64(reim_to_znx64_avx2_bnd63_fma, 63); } +TEST(reim_conversions, reim_to_znx64_avx2_bnd50_fma) { test_reim_to_znx64(reim_to_znx64_avx2_bnd50_fma, 50); } +#endif diff --git a/spqlios/lib/test/spqlios_reim_test.cpp b/spqlios/lib/test/spqlios_reim_test.cpp new file mode 100644 index 0000000..3432e32 --- /dev/null +++ b/spqlios/lib/test/spqlios_reim_test.cpp @@ -0,0 +1,477 @@ +#include + +#include + +#include "gtest/gtest.h" +#include "spqlios/commons_private.h" +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/reim/reim_fft_internal.h" +#include "spqlios/reim/reim_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, reim_fft_avx2_vs_fft_reim_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + // CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_FFT_PRECOMP* reimtables = new_reim_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + reim_fft_ref(reimtables, a2); + reim_fft_avx2_fma(reimtables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i] - a2[i]); + double dim = fabs(a1[nn / 2 + i] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, nn * 1e-10) << nn; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + // delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_ifft_avx2_vs_reim_ifft_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + // CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_IFFT_PRECOMP* reimtables = new_reim_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + reim_ifft_ref(reimtables, a2); + reim_ifft_avx2_fma(reimtables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i] - a2[i]); + double dim = fabs(a1[nn / 2 + i] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + 
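The reference/AVX comparisons above address the split ("reim") layout directly: the first m doubles are the real parts and the next m are the imaginary parts. For readers coming from the interleaved CPLX layout used elsewhere in these tests, here is the conversion in both directions; the same transformation is written out inline in reim_vs_regular_layout_mul_test further down.

#include <cstdint>

typedef double CPLX[2];  // interleaved (re, im) pairs, as in the cplx tests

// Interleaved -> split ("reim") layout: m real parts, then m imaginary parts.
static void cplx_to_reim(uint64_t m, double* reim, const CPLX* c) {
  for (uint64_t i = 0; i < m; ++i) {
    reim[i] = c[i][0];
    reim[m + i] = c[i][1];
  }
}

// Split -> interleaved layout.
static void reim_to_cplx(uint64_t m, CPLX* c, const double* reim) {
  for (uint64_t i = 0; i < m; ++i) {
    c[i][0] = reim[i];
    c[i][1] = reim[m + i];
  }
}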
ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + // delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_vecfft_addmul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM_FFTVEC_ADDMUL_PRECOMP* tbl = new_reim_fftvec_addmul_precomp(m); + ASSERT_TRUE(tbl != nullptr); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim_fftvec_addmul_ref(tbl, r1, a1, b1); + reim_fftvec_addmul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_addmul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_vecfft_mul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM_FFTVEC_MUL_PRECOMP* tbl = new_reim_fftvec_mul_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim_fftvec_mul_ref(tbl, r1, a1, b1); + reim_fftvec_mul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_mul_precomp(tbl); + } +} +#endif + +typedef void (*FILL_REIM_FFT_OMG_F)(const double entry_pwr, double** omg); +typedef void (*REIM_FFT_F)(double* dre, double* dim, const void* omega); + +// template to test a fixed-dimension fft vs. 
naive +template +void test_reim_fft_ref_vs_naive(FILL_REIM_FFT_OMG_F fill_omega_f, REIM_FFT_F reim_fft_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + double* omg = om; + fill_omega_f(0.25, &omg); + ASSERT_EQ(omg - om, ptrdiff_t(N)); // it may depend on N + for (uint64_t i = 0; i < N; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_fft_f(datacopy, datacopy + N, om); + reim_naive_fft(N, 0.25, data, data + N); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); +} + +template +void test_reim_fft_ref_vs_accel(REIM_FFT_F reim_fft_ref_f, REIM_FFT_F reim_fft_accel_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + for (uint64_t i = 0; i < N; ++i) { + om[i] = (rand() % 100) - 50; + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_fft_ref_f(datacopy, datacopy + N, om); + reim_fft_accel_f(data, data + N, om); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + if (d > 1e-15) { + for (uint64_t i = 0; i < N; ++i) { + printf("%" PRId64 " %lf %lf %lf %lf\n", i, data[i], data[N + i], datacopy[i], datacopy[N + i]); + } + ASSERT_LE(d, 0); + } +} + +TEST(fft, reim_fft16_ref_vs_naive) { test_reim_fft_ref_vs_naive<16>(fill_reim_fft16_omegas, reim_fft16_ref); } +#ifdef __aarch64__ +TEST(fft, reim_fft16_neon_vs_naive) { test_reim_fft_ref_vs_naive<16>(fill_reim_fft16_omegas_neon, reim_fft16_neon); } +#endif + +#ifdef __x86_64__ +TEST(fft, reim_fft16_ref_vs_fma) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref, reim_fft16_avx_fma); } +#endif + +#ifdef __aarch64__ +static void reim_fft16_ref_neon_pom(double* dre, double* dim, const void* omega) { + const double* pom = (double*) omega; + // put the omegas in neon order + double x_pom[] = { + pom[0], pom[1], pom[2], pom[3], + pom[4],pom[5], pom[6], pom[7], + pom[8], pom[10],pom[12], pom[14], + pom[9], pom[11],pom[13], pom[15] + }; + reim_fft16_ref(dre, dim, x_pom); +} +TEST(fft, reim_fft16_ref_vs_neon) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref_neon_pom, reim_fft16_neon); } +#endif + +TEST(fft, reim_fft8_ref_vs_naive) { test_reim_fft_ref_vs_naive<8>(fill_reim_fft8_omegas, reim_fft8_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_fft8_ref_vs_fma) { test_reim_fft_ref_vs_accel<8>(reim_fft8_ref, reim_fft8_avx_fma); } +#endif + +TEST(fft, reim_fft4_ref_vs_naive) { test_reim_fft_ref_vs_naive<4>(fill_reim_fft4_omegas, reim_fft4_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_fft4_ref_vs_fma) { test_reim_fft_ref_vs_accel<4>(reim_fft4_ref, reim_fft4_avx_fma); } +#endif + +TEST(fft, reim_fft2_ref_vs_naive) { test_reim_fft_ref_vs_naive<2>(fill_reim_fft2_omegas, reim_fft2_ref); } + +TEST(fft, reim_fft_bfs_16_ref_vs_naive) { + for (const uint64_t m : {16, 32, 64, 128, 256, 512, 1024, 2048}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_fft_bfs_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_fft_bfs_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + 
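The naive reference used as ground truth in these tests evaluates the polynomial modulo X^m - i, assuming entry_pwr = 0.25 denotes a quarter turn of the unit circle (the simple_fft test further down states the "mod X^4 - i" convention explicitly). The O(m^2) sketch below shows what such a check computes; it does not reproduce the library's output ordering or its split re/im calling convention.

#include <complex>
#include <cstdint>

// Sketch only: evaluate a degree-(m-1) complex polynomial at the m roots of i.
static void naive_eval_mod_xm_minus_i(uint64_t m, const std::complex<double>* coeffs,
                                      std::complex<double>* out) {
  const double pi = 3.14159265358979323846;
  for (uint64_t j = 0; j < m; ++j) {
    // j-th root of X^m = i, at angle 2*pi*(j + 1/4)/m
    const std::complex<double> omega = std::polar(1.0, 2.0 * pi * (double(j) + 0.25) / double(m));
    std::complex<double> acc(0.0), w(1.0);
    for (uint64_t k = 0; k < m; ++k) {
      acc += coeffs[k] * w;  // accumulate coeff * omega^k
      w *= omega;
    }
    out[j] = acc;
  }
}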
ASSERT_LE(d, 1e-7); + } +} + +TEST(fft, reim_fft_rec_16_ref_vs_naive) { + for (const uint64_t m : {2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_fft_rec_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_fft_rec_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5); + } +} + +TEST(fft, reim_fft_ref_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_FFT_PRECOMP* precomp = new_reim_fft_precomp(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_fft_ref(precomp, datacopy.data()); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_fft_precomp(precomp); + } +} + +#ifdef __aarch64__ +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers); +EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d); +TEST(fft, reim_fft_neon_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_FFT_PRECOMP* precomp = new_reim_fft_precomp_neon(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_fft_neon(precomp, datacopy.data()); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_fft_precomp(precomp); + } +} +#endif + +typedef void (*FILL_REIM_IFFT_OMG_F)(const double entry_pwr, double** omg); +typedef void (*REIM_IFFT_F)(double* dre, double* dim, const void* omega); + +// template to test a fixed-dimension fft vs. 
naive +template +void test_reim_ifft_ref_vs_naive(FILL_REIM_IFFT_OMG_F fill_omega_f, REIM_IFFT_F reim_ifft_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + double* omg = om; + fill_omega_f(0.25, &omg); + ASSERT_EQ(omg - om, ptrdiff_t(N)); // it may depend on N + for (uint64_t i = 0; i < N; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_ifft_f(datacopy, datacopy + N, om); + reim_naive_ifft(N, 0.25, data, data + N); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); +} + +template +void test_reim_ifft_ref_vs_accel(REIM_IFFT_F reim_ifft_ref_f, REIM_IFFT_F reim_ifft_accel_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + for (uint64_t i = 0; i < N; ++i) { + om[i] = (rand() % 100) - 50; + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_ifft_ref_f(datacopy, datacopy + N, om); + reim_ifft_accel_f(data, data + N, om); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + if (d > 1e-15) { + for (uint64_t i = 0; i < N; ++i) { + printf("%" PRId64 " %lf %lf %lf %lf\n", i, data[i], data[N + i], datacopy[i], datacopy[N + i]); + } + ASSERT_LE(d, 0); + } +} + +TEST(fft, reim_ifft16_ref_vs_naive) { test_reim_ifft_ref_vs_naive<16>(fill_reim_ifft16_omegas, reim_ifft16_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft16_ref_vs_fma) { test_reim_ifft_ref_vs_accel<16>(reim_ifft16_ref, reim_ifft16_avx_fma); } +#endif + +TEST(fft, reim_ifft8_ref_vs_naive) { test_reim_ifft_ref_vs_naive<8>(fill_reim_ifft8_omegas, reim_ifft8_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft8_ref_vs_fma) { test_reim_ifft_ref_vs_accel<8>(reim_ifft8_ref, reim_ifft8_avx_fma); } +#endif + +TEST(fft, reim_ifft4_ref_vs_naive) { test_reim_ifft_ref_vs_naive<4>(fill_reim_ifft4_omegas, reim_ifft4_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft4_ref_vs_fma) { test_reim_ifft_ref_vs_accel<4>(reim_ifft4_ref, reim_ifft4_avx_fma); } +#endif + +TEST(fft, reim_ifft2_ref_vs_naive) { test_reim_ifft_ref_vs_naive<2>(fill_reim_ifft2_omegas, reim_ifft2_ref); } + +TEST(fft, reim_ifft_bfs_16_ref_vs_naive) { + for (const uint64_t m : {16, 32, 64, 128, 256, 512, 1024, 2048}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_ifft_bfs_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_ifft_bfs_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); + } +} + +TEST(fft, reim_ifft_rec_16_ref_vs_naive) { + for (const uint64_t m : {2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_ifft_rec_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_ifft_rec_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + 
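A small note on the omega-table API exercised here: the fill_*_omegas helpers take a double pointer and advance it past whatever twiddles they wrote, which is exactly what the `omg - om.data()` assertions measure. A minimal illustration of the idiom with a hypothetical writer (not a library function):

#include <cassert>
#include <cstdint>

// Hypothetical writer: emit `count` values and advance the caller's cursor.
static void write_twiddles_sketch(uint64_t count, double value, double** omg) {
  for (uint64_t i = 0; i < count; ++i) *(*omg)++ = value;  // write, then advance
}

// Usage:
//   double buf[32];
//   double* cur = buf;
//   write_twiddles_sketch(8, 0.5, &cur);
//   assert(cur - buf == 8);  // the caller can bound the total table size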
reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5); + } +} + +TEST(fft, reim_ifft_ref_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_IFFT_PRECOMP* precomp = new_reim_ifft_precomp(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_ifft_ref(precomp, datacopy.data()); + reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_ifft_precomp(precomp); + } +} diff --git a/spqlios/lib/test/spqlios_svp_product_test.cpp b/spqlios/lib/test/spqlios_svp_product_test.cpp new file mode 100644 index 0000000..db7cb23 --- /dev/null +++ b/spqlios/lib/test/spqlios_svp_product_test.cpp @@ -0,0 +1,28 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +// todo: remove when registered +typedef typeof(fft64_svp_prepare_ref) SVP_PREPARE_F; + +void test_fft64_svp_prepare(SVP_PREPARE_F svp_prepare) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + znx_i64 in = znx_i64::random_log2bound(n, 40); + fft64_svp_ppol_layout out(n); + reim_fft64vec expect = simple_fft64(in); + svp_prepare(module, out.data, in.data()); + const double* ed = (double*)expect.data(); + const double* ac = (double*)out.data; + for (uint64_t i = 0; i < n; ++i) { + ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; + } + delete_module_info(module); + } +} + +TEST(svp_prepare, fft64_svp_prepare_ref) { test_fft64_svp_prepare(fft64_svp_prepare_ref); } +TEST(svp_prepare, svp_prepare) { test_fft64_svp_prepare(svp_prepare); } diff --git a/spqlios/lib/test/spqlios_svp_test.cpp b/spqlios/lib/test/spqlios_svp_test.cpp new file mode 100644 index 0000000..c75f00e --- /dev/null +++ b/spqlios/lib/test/spqlios_svp_test.cpp @@ -0,0 +1,47 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +void test_fft64_svp_apply_dft(SVP_APPLY_DFT_F svp) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + // poly 1 to multiply - create and prepare + fft64_svp_ppol_layout ppol(n); + ppol.fill_random(1.); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + // poly 2 to multiply + znx_vec_i64_layout a(n, sa, a_sl); + a.fill_random(19); + // original operation result + fft64_vec_znx_dft_layout res(n, sr); + thash hash_a_before = a.content_hash(); + thash hash_ppol_before = ppol.content_hash(); + svp(module, res.data, sr, ppol.data, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_a_before); + ASSERT_EQ(ppol.content_hash(), hash_ppol_before); + // create expected value + reim_fft64vec ppo = ppol.get_copy(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = ppo * simple_fft64(a.get_copy_zext(i)); + } + // this is the largest precision we can safely expect + double prec_expect = n * pow(2., 19 - 52); + for (uint64_t i = 0; i < 
sr; ++i) { + reim_fft64vec actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); + } + } + } + + delete_module_info(module); + } +} + +TEST(fft64_svp_apply_dft, svp_apply_dft) { test_fft64_svp_apply_dft(svp_apply_dft); } +TEST(fft64_svp_apply_dft, fft64_svp_apply_dft_ref) { test_fft64_svp_apply_dft(fft64_svp_apply_dft_ref); } diff --git a/spqlios/lib/test/spqlios_test.cpp b/spqlios/lib/test/spqlios_test.cpp new file mode 100644 index 0000000..f6048c1 --- /dev/null +++ b/spqlios/lib/test/spqlios_test.cpp @@ -0,0 +1,493 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "spqlios/cplx/cplx_fft_internal.h" + +using namespace std; + +/*namespace { +bool very_close(const double& a, const double& b) { + bool reps = (abs(a - b) < 1e-5); + if (!reps) { + cerr << "not close: " << a << " vs. " << b << endl; + } + return reps; +} + +}*/ // namespace + +TEST(fft, fftvec_convolution) { + uint64_t nn = 65536; // vary accross (8192, 16384), 32768, 65536 + static const uint64_t k = 18; // vary from 10 to 20 + // double* buf_fft = fft_precomp_get_buffer(tables, 0); + // double* buf_ifft = ifft_precomp_get_buffer(itables, 0); + double* a = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* b = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* dist_vector = (double*)spqlios_alloc_custom_align(32, nn * 8); + int64_t p = UINT64_C(1) << k; + printf("p size: %" PRId64 "\n", p); + for (uint32_t i = 0; i < nn; i++) { + a[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i] = (rand() % p) - p / 2; + a2[i] = 0; + } + cplx_fft_simple(nn / 2, a); + cplx_fft_simple(nn / 2, b); + cplx_fftvec_addmul_simple(nn / 2, a2, a, b); + cplx_ifft_simple(nn / 2, a2); // normalization is missing + double distance = 0; + // for (int32_t i = 0; i < 10; i++) { + // printf("%lf %lf\n", a2[i], a2[i] / (nn / 2.)); + //} + for (uint32_t i = 0; i < nn; i++) { + double curdist = fabs(a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.))); + if (distance < curdist) distance = curdist; + dist_vector[i] = a2[i] / (nn / 2.) 
- rint(a2[i] / (nn / 2.)); + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.5); // switch from previous 0.1 to 0.5 per experiment 1 reqs + // double a3[] = {2,4,4,4,5,5,7,9}; //instead of dist_vector, for test + // nn = 8; + double mean = 0; + for (uint32_t i = 0; i < nn; i++) { + mean = mean + dist_vector[i]; + } + mean = mean / nn; + double variance = 0; + for (uint32_t i = 0; i < nn; i++) { + variance = variance + pow((mean - dist_vector[i]), 2); + } + double stdev = sqrt(variance / nn); + printf("stdev: %lf\n", stdev); + + spqlios_free(a); + spqlios_free(b); + spqlios_free(a2); + spqlios_free(dist_vector); +} + +typedef double CPLX[2]; +EXPORT uint32_t revbits(uint32_t i, uint32_t v); + +void cplx_zero(CPLX r); +void cplx_addmul(CPLX r, const CPLX a, const CPLX b); + +void halfcfft_eval(CPLX res, uint32_t nn, uint32_t k, const CPLX* coeffs, const CPLX* powomegas); +void halfcfft_naive(uint32_t nn, CPLX* data); + +EXPORT void cplx_set(CPLX, const CPLX); +EXPORT void citwiddle(CPLX, CPLX, const CPLX); +EXPORT void invcitwiddle(CPLX, CPLX, const CPLX); +EXPORT void ctwiddle(CPLX, CPLX, const CPLX); +EXPORT void invctwiddle(CPLX, CPLX, const CPLX); + +#include "../spqlios/cplx/cplx_fft_private.h" +#include "../spqlios/reim/reim_fft_internal.h" +#include "../spqlios/reim/reim_fft_private.h" +#include "../spqlios/reim4/reim4_fftvec_internal.h" +#include "../spqlios/reim4/reim4_fftvec_private.h" + +TEST(fft, simple_fft_test) { // checks the simple_fft api + uint64_t nn = 8; // fixed small size for this api check + // double* buf_fft = fft_precomp_get_buffer(tables, 0); + // double* buf_ifft = ifft_precomp_get_buffer(itables, 0); + + // define the complex coefficients of two polynomials mod X^4-i + double a[4][2] = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}, {7.7, 8.8}}; + double b[4][2] = {{9., 10.}, {11., 12.}, {13., 14.}, {15., 16.}}; + double c[4][2]; // for the result + double a2[4][2]; // for testing inverse fft + memcpy(a2, a, 8 * nn); + cplx_fft_simple(4, a); + cplx_fft_simple(4, b); + cplx_fftvec_mul_simple(4, c, a, b); + cplx_ifft_simple(4, c); + // c contains the complex coefficients 4.a*b mod X^4-i + cplx_ifft_simple(4, a); + + double distance = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dist = fabs(a[i][0] / 4. - a2[i][0]); + if (distance < dist) distance = dist; + dist = fabs(a[i][1] / 4. 
- a2[i][1]); + if (distance < dist) distance = dist; + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.1); // switch from previous 0.1 to 0.5 per experiment 1 reqs + + for (uint32_t i = 0; i < nn / 4; i++) { + printf("%lf %lf\n", a2[i][0], a[i][0] / (nn / 2.)); + printf("%lf %lf\n", a2[i][1], a[i][1] / (nn / 2.)); + } +} + +TEST(fft, reim_test) { + // double a[16] __attribute__ ((aligned(32)))= {1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.,11.,12.,13.,14.,15.,16.}; + // double b[16] __attribute__ ((aligned(32)))= {17.,18.,19.,20.,21.,22.,23.,24.,25.,26.,27.,28.,29.,30., 31.,32.}; + // double c[16] __attribute__ ((aligned(32))); // for the result in reference layout + double a[16] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10., 11., 12., 13., 14., 15., 16.}; + double b[16] = {17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.}; + double c[16]; // for the result in reference layout + reim_fft_simple(8, a); + reim_fft_simple(8, b); + reim_fftvec_mul_simple(8, c, a, b); + reim_ifft_simple(8, c); +} + +TEST(fft, reim_vs_regular_layout_mul_test) { + uint64_t nn = 16; + + // define the complex coefficients of two polynomials mod X^8-i + + double a1[8][2] __attribute__((aligned(32))) = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}, {7.7, 8.8}, + {9.9, 10.}, {11., 12.}, {13., 14.}, {15., 16.}}; + double b1[8][2] __attribute__((aligned(32))) = {{17., 18.}, {19., 20.}, {21., 22.}, {23., 24.}, + {25., 26.}, {27., 28.}, {29., 30.}, {31., 32.}}; + double c1[8][2] __attribute__((aligned(32))); // for the result + double c2[16] __attribute__((aligned(32))); // for the result + double c3[8][2] __attribute__((aligned(32))); // for the result + + double* a2 = + (double*)spqlios_alloc_custom_align(32, nn / 2 * 2 * sizeof(double)); // for storing the coefs in reim layout + double* b2 = + (double*)spqlios_alloc_custom_align(32, nn / 2 * 2 * sizeof(double)); // for storing the coefs in reim layout + // double* c2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); // for storing the coefs in reim + // layout + + // organise the coefficients in the reim layout + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a1[i][0]; // a1 = a2, b1 = b2 + a2[nn / 2 + i] = a1[i][1]; + b2[i] = b1[i][0]; + b2[nn / 2 + i] = b1[i][1]; + } + + // fft + cplx_fft_simple(8, a1); + reim_fft_simple(8, a2); + + cplx_fft_simple(8, b1); + reim_fft_simple(8, b2); + + cplx_fftvec_mul_simple(8, c1, a1, b1); + reim_fftvec_mul_simple(8, c2, a2, b2); + + cplx_ifft_simple(8, c1); + reim_ifft_simple(8, c2); + + // check base layout and reim layout result in the same values + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + // printf("RE: cplx_result %lf and reim_result %lf \n", c1[i][0], c2[i]); + // printf("IM: cplx_result %lf and reim_result %lf \n", c1[i][1], c2[nn / 2 + i]); + double dre = fabs(c1[i][0] - c2[i]); + double dim = fabs(c1[i][1] - c2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + + // check converting back to base layout: + + for (uint32_t i = 0; i < nn / 2; i++) { + c3[i][0] = c2[i]; + c3[i][1] = c2[8 + i]; + } + + d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(c1[i][0] - c3[i][0]); + double dim = fabs(c1[i][1] - c3[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + + spqlios_free(a2); + spqlios_free(b2); + // spqlios_free(c2); +} + +TEST(fft, fftvec_convolution_recursiveoverk) { + static const uint64_t nn = 32768; // vary accross 
(8192, 16384), 32768, 65536 + double* a = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* b = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* dist_vector = (double*)spqlios_alloc_custom_align(32, nn * 8); + + printf("N size: %" PRId64 "\n", nn); + + for (uint32_t k = 14; k <= 24; k++) { // vary k + printf("k size: %" PRId32 "\n", k); + int64_t p = UINT64_C(1) << k; + for (uint32_t i = 0; i < nn; i++) { + a[i] = (rand() % p) - p / 2; + b[i] = (rand() % p) - p / 2; + a2[i] = 0; + } + cplx_fft_simple(nn / 2, a); + cplx_fft_simple(nn / 2, b); + cplx_fftvec_addmul_simple(nn / 2, a2, a, b); + cplx_ifft_simple(nn / 2, a2); + double distance = 0; + for (uint32_t i = 0; i < nn; i++) { + double curdist = fabs(a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.))); + if (distance < curdist) distance = curdist; + dist_vector[i] = a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.)); + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.5); // switch from previous 0.1 to 0.5 per experiment 1 reqs + double mean = 0; + for (uint32_t i = 0; i < nn; i++) { + mean = mean + dist_vector[i]; + } + mean = mean / nn; + double variance = 0; + for (uint32_t i = 0; i < nn; i++) { + variance = variance + pow((mean - dist_vector[i]), 2); + } + double stdev = sqrt(variance / nn); + printf("stdev: %lf\n", stdev); + } + + spqlios_free(a); + spqlios_free(b); + spqlios_free(a2); + spqlios_free(dist_vector); +} + +#ifdef __x86_64__ +TEST(fft, cplx_fft_ref_vs_fft_reim_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_FFT_PRECOMP* reimtables = new_reim_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a[i][0]; + a2[nn / 2 + i] = a[i][1]; + } + cplx_fft_ref(tables, a1); + reim_fft_ref(reimtables, a2); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i]); + double dim = fabs(a1[i][1] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +TEST(fft, cplx_ifft_ref_vs_reim_ifft_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_IFFT_PRECOMP* tables = new_cplx_ifft_precomp(m, 0); + REIM_IFFT_PRECOMP* reimtables = new_reim_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a[i][0]; + a2[nn / 2 + i] = a[i][1]; + } + cplx_ifft_ref(tables, a1); + reim_ifft_ref(reimtables, a2); + double 
d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i]); + double dim = fabs(a1[i][1] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} + +#ifdef __x86_64__ +TEST(fft, reim4_vecfft_addmul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FFTVEC_ADDMUL_PRECOMP* tbl = new_reim4_fftvec_addmul_precomp(m); + ASSERT_TRUE(tbl != nullptr); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_fftvec_addmul_ref(tbl, r1, a1, b1); + reim4_fftvec_addmul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_fftvec_addmul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_vecfft_mul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FFTVEC_MUL_PRECOMP* tbl = new_reim4_fftvec_mul_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_fftvec_mul_ref(tbl, r1, a1, b1); + reim4_fftvec_mul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_mul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_from_cplx_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FROM_CPLX_PRECOMP* tbl = new_reim4_from_cplx_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = 
(double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_from_cplx_ref(tbl, r1, a1); + reim4_from_cplx_fma(tbl, r2, a2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_from_cplx_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_to_cplx_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_TO_CPLX_PRECOMP* tbl = new_reim4_to_cplx_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_to_cplx_ref(tbl, r1, a1); + reim4_to_cplx_fma(tbl, r2, a2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_from_cplx_precomp(tbl); + } +} +#endif diff --git a/spqlios/lib/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp b/spqlios/lib/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp new file mode 100644 index 0000000..5b36ed0 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp @@ -0,0 +1,42 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "testlib/vec_rnx_layout.h" + +static void test_rnx_approxdecomp(RNX_APPROXDECOMP_FROM_TNXDBL_F approxdec) { + for (const uint64_t nn : {2, 4, 8, 32}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNXDBL_APPROXDECOMP_GADGET* gadget = new_tnxdbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell, ell - 1, ell + 1}) { + const uint64_t res_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, 1, nn); + in.fill_random(3); + rnx_vec_f64_layout out(nn, res_size, res_sl); + approxdec(module, gadget, out.data(), res_size, res_sl, in.data()); + // reconstruct the output + uint64_t msize = std::min(res_size, ell); + double err_bnd = msize == ell ? 
pow(2., -double(msize * k) - 1) : pow(2., -double(msize * k)); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in.data()[j]; + double out_j = 0; + for (uint64_t i = 0; i < res_size; ++i) { + out_j += out.get_copy(i).get_coeff(j) * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tnxdbl_approxdecomp_gadget(gadget); + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, rnx_approxdecomp) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl); } +TEST(vec_rnx, rnx_approxdecomp_ref) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_ref); } +#ifdef __x86_64__ +TEST(vec_rnx, rnx_approxdecomp_avx) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_avx); } +#endif diff --git a/spqlios/lib/test/spqlios_vec_rnx_conversions_test.cpp b/spqlios/lib/test/spqlios_vec_rnx_conversions_test.cpp new file mode 100644 index 0000000..3b629e6 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_rnx_conversions_test.cpp @@ -0,0 +1,134 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_RNX*, // + DST_T* res, uint64_t res_size, uint64_t res_sl, // + const SRC_T* a, uint64_t a_size, uint64_t a_sl), // + DST_T (*ideal_conv_f)(SRC_T x), // + SRC_T (*random_f)() // +) { + for (uint64_t nn : {2, 4, 16, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t a_size : {0, 1, 2, 5}) { + for (uint64_t res_size : {0, 1, 3, 5}) { + for (uint64_t trials = 0; trials < 20; ++trials) { + uint64_t a_sl = nn + uniform_u64_bits(2); + uint64_t res_sl = nn + uniform_u64_bits(2); + std::vector a(a_sl * a_size); + std::vector res(res_sl * res_size); + uint64_t msize = std::min(a_size, res_size); + for (uint64_t i = 0; i < a_size; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + a[i * a_sl + j] = random_f(); + } + } + conv_f(module, res.data(), res_size, res_sl, a.data(), a_size, a_sl); + for (uint64_t i = 0; i < msize; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + SRC_T aij = a[i * a_sl + j]; + DST_T expect = ideal_conv_f(aij); + DST_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + for (uint64_t i = msize; i < res_size; ++i) { + DST_T expect = 0; + for (uint64_t j = 0; j < nn; ++j) { + SRC_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_vec_rnx_to_tnx32(VEC_RNX_TO_TNX32_F vec_rnx_to_tnx32_f) { + test_conv(vec_rnx_to_tnx32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32_ref) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_vec_rnx_from_tnx32(VEC_RNX_FROM_TNX32_F vec_rnx_from_tnx32_f) { + test_conv(vec_rnx_from_tnx32_f, ideal_tn32_to_dbl, random_t32); +} + +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32) { test_vec_rnx_from_tnx32(vec_rnx_from_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32_ref) { test_vec_rnx_from_tnx32(vec_rnx_from_tnx32_ref); } 
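+
+// Illustrative spot-check (a minimal sketch, not part of the library's own coverage): the
+// ideal maps above are exact whenever no rounding occurs, e.g. -0.75 reduces mod 1 to 0.25
+// and 0.25 * 2^32 = 2^30, so the two conversions round-trip on this value.
+TEST(vec_rnx_arithmetic, ideal_tnx32_maps_spot_check) {
+  EXPECT_EQ(ideal_dbl_to_tn32(-0.75), int32_t(1) << 30);  // -0.75 mod 1 = 0.25 -> 2^30
+  EXPECT_EQ(ideal_tn32_to_dbl(int32_t(1) << 30), 0.25);   // 2^30 / 2^32 = 0.25
+}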
+ +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_to_znx32(VEC_RNX_TO_ZNX32_F vec_rnx_to_znx32_f) { + test_conv(vec_rnx_to_znx32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_to_znx32) { test_vec_rnx_to_znx32(vec_rnx_to_znx32); } +TEST(zn_arithmetic, vec_rnx_to_znx32_ref) { test_vec_rnx_to_znx32(vec_rnx_to_znx32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_from_znx32(VEC_RNX_FROM_ZNX32_F vec_rnx_from_znx32_f) { + test_conv(vec_rnx_from_znx32_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_from_znx32) { test_vec_rnx_from_znx32(vec_rnx_from_znx32); } +TEST(zn_arithmetic, vec_rnx_from_znx32_ref) { test_vec_rnx_from_znx32(vec_rnx_from_znx32_ref); } + +static double ideal_dbl_to_tndbl(double a) { return a - rint(a); } + +static void test_vec_rnx_to_tnxdbl(VEC_RNX_TO_TNXDBL_F vec_rnx_to_tnxdbl_f) { + test_conv(vec_rnx_to_tnxdbl_f, ideal_dbl_to_tndbl, random_f64_10); +} + +TEST(zn_arithmetic, vec_rnx_to_tnxdbl) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl); } +TEST(zn_arithmetic, vec_rnx_to_tnxdbl_ref) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl_ref); } + +#if 0 +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } +#endif diff --git a/spqlios/lib/test/spqlios_vec_rnx_ppol_test.cpp b/spqlios/lib/test/spqlios_vec_rnx_ppol_test.cpp new file mode 100644 index 0000000..e5e0cbd --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_rnx_ppol_test.cpp @@ -0,0 +1,73 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "test/testlib/vec_rnx_layout.h" + +static void test_vec_rnx_svp_prepare(RNX_SVP_PREPARE_F* rnx_svp_prepare, BYTES_OF_RNX_SVP_PPOL_F* tmp_bytes) { + for (uint64_t n : {2, 4, 8, 64}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + const double invm = 1. 
/ mod->m; + + rnx_f64 in = rnx_f64::random_log2bound(n, 40); + rnx_f64 in_divide_by_m = rnx_f64::zero(n); + for (uint64_t i = 0; i < n; ++i) { + in_divide_by_m.set_coeff(i, in.get_coeff(i) * invm); + } + fft64_rnx_svp_ppol_layout out(n); + reim_fft64vec expect = simple_fft64(in_divide_by_m); + rnx_svp_prepare(mod, out.data, in.data()); + const double* ed = (double*)expect.data(); + const double* ac = (double*)out.data; + for (uint64_t i = 0; i < n; ++i) { + ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_prepare) { test_vec_rnx_svp_prepare(rnx_svp_prepare, bytes_of_rnx_svp_ppol); } +TEST(vec_rnx, vec_rnx_svp_prepare_ref) { + test_vec_rnx_svp_prepare(fft64_rnx_svp_prepare_ref, fft64_bytes_of_rnx_svp_ppol); +} + +static void test_vec_rnx_svp_apply(RNX_SVP_APPLY_F* apply) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + + // poly 1 to multiply - create and prepare + fft64_rnx_svp_ppol_layout ppol(n); + ppol.fill_random(1.); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + uint64_t r_sl = n + uniform_u64_bits(2); + // poly 2 to multiply + rnx_vec_f64_layout a(n, sa, a_sl); + a.fill_random(19); + + // original operation result + rnx_vec_f64_layout res(n, sr, r_sl); + thash hash_a_before = a.content_hash(); + thash hash_ppol_before = ppol.content_hash(); + apply(mod, res.data(), sr, r_sl, ppol.data, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_a_before); + ASSERT_EQ(ppol.content_hash(), hash_ppol_before); + // create expected value + reim_fft64vec ppo = ppol.get_copy(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_ifft64(ppo * simple_fft64(a.get_copy_zext(i))); + } + // this is the largest precision we can safely expect + double prec_expect = n * pow(2., 19 - 50); + for (uint64_t i = 0; i < sr; ++i) { + rnx_f64 actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); + } + } + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_apply) { test_vec_rnx_svp_apply(rnx_svp_apply); } +TEST(vec_rnx, vec_rnx_svp_apply_ref) { test_vec_rnx_svp_apply(fft64_rnx_svp_apply_ref); } diff --git a/spqlios/lib/test/spqlios_vec_rnx_test.cpp b/spqlios/lib/test/spqlios_vec_rnx_test.cpp new file mode 100644 index 0000000..2990299 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_rnx_test.cpp @@ -0,0 +1,417 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "testlib/vec_rnx_layout.h" + +// disabling this test by default, since it depicts on purpose wrong accesses +#if 0 +TEST(rnx_layout, valgrind_antipattern_test) { + uint64_t n = 4; + rnx_vec_f64_layout v(n, 7, 13); + // this should be ok + v.set(0, rnx_f64::zero(n)); + // this should abort (wrong ring dimension) + ASSERT_DEATH(v.set(3, rnx_f64::zero(2 * n)), ""); + // this should abort (out of bounds) + ASSERT_DEATH(v.set(8, rnx_f64::zero(n)), ""); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), rnx_f64::zero(n)); + // should be an uninit read + ASSERT_TRUE(!(v.get_copy_zext(2) == rnx_f64::zero(n))); // should be uninit + // should be an invalid read (inter-slice) + ASSERT_NE(v.data()[4], 0); + ASSERT_EQ(v.data()[2], 0); // should be ok + // should be an uninit read + ASSERT_NE(v.data()[13], 0); // should be uninit +} +#endif + +// test of binary operations + +// test for out of place calls +template +void 
test_vec_rnx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 8, 128}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + for (uint64_t sc : {7, 13, 15}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + uint64_t c_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + rnx_vec_f64_layout lc(n, sc, c_sl); + std::vector expect(sc); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sc; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lc.data(), sc, c_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sc; ++i) { + ASSERT_EQ(lc.get_copy_zext(i), expect[i]); + } + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace1 calls +template +void test_vec_rnx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace2 calls +template +void test_vec_rnx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {4, 32, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace3 calls +template +void test_vec_rnx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 16, 1024}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 
0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), la.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + la.data(), sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_rnx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_binop_outplace(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace1(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace2(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace3(binop, ref_binop); +} + +static rnx_f64 poly_add(const rnx_f64& a, const rnx_f64& b) { return a + b; } +static rnx_f64 poly_sub(const rnx_f64& a, const rnx_f64& b) { return a - b; } +TEST(vec_rnx, vec_rnx_add) { test_vec_rnx_elemw_binop(vec_rnx_add, poly_add); } +TEST(vec_rnx, vec_rnx_add_ref) { test_vec_rnx_elemw_binop(vec_rnx_add_ref, poly_add); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_add_avx) { test_vec_rnx_elemw_binop(vec_rnx_add_avx, poly_add); } +#endif +TEST(vec_rnx, vec_rnx_sub) { test_vec_rnx_elemw_binop(vec_rnx_sub, poly_sub); } +TEST(vec_rnx, vec_rnx_sub_ref) { test_vec_rnx_elemw_binop(vec_rnx_sub_ref, poly_sub); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_sub_avx) { test_vec_rnx_elemw_binop(vec_rnx_sub_avx, poly_sub); } +#endif + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_param_outplace(ACTUAL_FCN test_mul_xp_minus_one, EXPECT_FCN ref_mul_xp_minus_one, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_mul_xp_minus_one(p, la.get_copy_zext(i)); + } + test_mul_xp_minus_one(mod, // + p, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_param_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(p, la.get_copy_zext(i)); + } + actual_function(mod, // N + p, //; + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +static int64_t random_mul_xp_minus_one_param() { return uniform_i64(); } +static int64_t random_automorphism_param() { return 2 * uniform_i64() + 1; } +static int64_t random_rotation_param() { return 
uniform_i64(); } + +template +void test_vec_rnx_elemw_mul_xp_minus_one(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_mul_xp_minus_one_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_mul_xp_minus_one_param); +} +template +void test_vec_rnx_elemw_rotate(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_rotation_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_rotation_param); +} +template +void test_vec_rnx_elemw_automorphism(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_automorphism_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_automorphism_param); +} + +static rnx_f64 poly_mul_xp_minus_one(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p) - a.get_coeff(i)); + } + return res; +} +static rnx_f64 poly_rotate(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p)); + } + return res; +} +static rnx_f64 poly_automorphism(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i * p, a.get_coeff(i)); + } + return res; +} + +TEST(vec_rnx, vec_rnx_mul_xp_minus_one) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one, poly_mul_xp_minus_one); +} +TEST(vec_rnx, vec_rnx_mul_xp_minus_one_ref) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one_ref, poly_mul_xp_minus_one); +} + +TEST(vec_rnx, vec_rnx_rotate) { test_vec_rnx_elemw_rotate(vec_rnx_rotate, poly_rotate); } +TEST(vec_rnx, vec_rnx_rotate_ref) { test_vec_rnx_elemw_rotate(vec_rnx_rotate_ref, poly_rotate); } +TEST(vec_rnx, vec_rnx_automorphism) { test_vec_rnx_elemw_automorphism(vec_rnx_automorphism, poly_automorphism); } +TEST(vec_rnx, vec_rnx_automorphism_ref) { + test_vec_rnx_elemw_automorphism(vec_rnx_automorphism_ref, poly_automorphism); +} + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_outplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, 
rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_rnx_elemw_unop(ACTUAL_FCN unnop, EXPECT_FCN ref_unnop) { + test_vec_rnx_elemw_unop_outplace(unnop, ref_unnop); + test_vec_rnx_elemw_unop_inplace(unnop, ref_unnop); +} + +static rnx_f64 poly_copy(const rnx_f64& a) { return a; } +static rnx_f64 poly_negate(const rnx_f64& a) { return -a; } + +TEST(vec_rnx, vec_rnx_copy) { test_vec_rnx_elemw_unop(vec_rnx_copy, poly_copy); } +TEST(vec_rnx, vec_rnx_copy_ref) { test_vec_rnx_elemw_unop(vec_rnx_copy_ref, poly_copy); } +TEST(vec_rnx, vec_rnx_negate) { test_vec_rnx_elemw_unop(vec_rnx_negate, poly_negate); } +TEST(vec_rnx, vec_rnx_negate_ref) { test_vec_rnx_elemw_unop(vec_rnx_negate_ref, poly_negate); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_negate_avx) { test_vec_rnx_elemw_unop(vec_rnx_negate_avx, poly_negate); } +#endif + +// test for inplace calls +void test_vec_rnx_zero(VEC_RNX_ZERO_F actual_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + const rnx_f64 ZERO = rnx_f64::zero(n); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + actual_function(mod, // N + la.data(), sa, a_sl // res + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), ZERO) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +TEST(vec_rnx, vec_rnx_zero) { test_vec_rnx_zero(vec_rnx_zero); } + +TEST(vec_rnx, vec_rnx_zero_ref) { test_vec_rnx_zero(vec_rnx_zero_ref); } diff --git a/spqlios/lib/test/spqlios_vec_rnx_vmp_test.cpp b/spqlios/lib/test/spqlios_vec_rnx_vmp_test.cpp new file mode 100644 index 0000000..9bbb9d7 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_rnx_vmp_test.cpp @@ -0,0 +1,291 @@ +#include "gtest/gtest.h" +#include "../spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "../spqlios/reim/reim_fft.h" +#include "testlib/vec_rnx_layout.h" + +static void test_vmp_apply_dft_to_dft_outplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, 
out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft_inplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 2, 6}) { + for (uint64_t mat_ncols : {1, 2, 7, 8}) { + for (uint64_t in_size : {1, 3, 6}) { + for (uint64_t out_size : {1, 3, 6}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in_out.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = in_out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_dft_to_dft_outplace(apply, tmp_bytes); + test_vmp_apply_dft_to_dft_inplace(apply, tmp_bytes); +} + +TEST(vec_rnx, vmp_apply_to_dft) { + test_vmp_apply_dft_to_dft(rnx_vmp_apply_dft_to_dft, rnx_vmp_apply_dft_to_dft_tmp_bytes); +} +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_ref) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_ref, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_avx) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_avx, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_prepare + +static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, + RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + const double one_over_m = 2. 
/ nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, vmp_prepare_contiguous) { + test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_contiguous_tmp_bytes); +} +TEST(vec_rnx, fft64_vmp_prepare_contiguous_ref) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_prepare_contiguous_avx) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_apply_dft_to_dft + +static void test_vmp_apply_tmp_a_outplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, 
mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a_inplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in_out.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = in_out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_tmp_a_outplace(apply, tmp_bytes); + test_vmp_apply_tmp_a_inplace(apply, tmp_bytes); +} + +TEST(vec_znx, fft64_vmp_apply_tmp_a) { test_vmp_apply_tmp_a(rnx_vmp_apply_tmp_a, rnx_vmp_apply_tmp_a_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_tmp_a_ref) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_ref, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_tmp_a_avx) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_avx, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx); +} +#endif diff --git a/spqlios/lib/test/spqlios_vec_znx_big_test.cpp b/spqlios/lib/test/spqlios_vec_znx_big_test.cpp new file mode 100644 index 0000000..4dd1170 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_znx_big_test.cpp @@ -0,0 +1,265 @@ +#include + +#include "spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "test/testlib/polynomial_vector.h" +#include "testlib/fft64_layouts.h" +#include "testlib/test_commons.h" + +#define def_rand_big(varname, ringdim, varsize) \ + fft64_vec_znx_big_layout varname(ringdim, varsize); \ + varname.fill_random() + +#define def_rand_small(varname, ringdim, varsize) \ + znx_vec_i64_layout varname(ringdim, varsize, 2 * ringdim); \ + varname.fill_random() + +#define test_prelude(ringdim, moduletype, dim1, dim2, dim3) \ + uint64_t n = ringdim; \ + MODULE* module = new_module_info(ringdim, moduletype); \ + for (uint64_t sa : {dim1, dim2, dim3}) { \ + for (uint64_t sb : {dim1, dim2, dim3}) { \ + for (uint64_t sr : {dim1, dim2, dim3}) + 
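+// Note on the helper macros: test_prelude(...) above binds `n` and `module` and opens nested
+// loops over the sizes sa, sb and sr (the innermost loop takes the test body that follows the
+// macro as its statement), while test_end() below closes those loops and releases the module,
+// so each test body runs once per (sa, sb, sr) combination.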
+#define test_end() \ + } \ + } \ + free(module) + +void test_fft64_vec_znx_big_add(VEC_ZNX_BIG_ADD_F vec_znx_big_add_fcn) { + test_prelude(8, FFT64, 3, 5, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data, sa, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_add_small(VEC_ZNX_BIG_ADD_SMALL_F vec_znx_big_add_fcn) { + test_prelude(16, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data, sa, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_add_small2(VEC_ZNX_BIG_ADD_SMALL2_F vec_znx_big_add_fcn) { + test_prelude(64, FFT64, 3, 6, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add) { test_fft64_vec_znx_big_add(fft64_vec_znx_big_add); } +TEST(vec_znx_big, vec_znx_big_add) { test_fft64_vec_znx_big_add(vec_znx_big_add); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add_small) { test_fft64_vec_znx_big_add_small(fft64_vec_znx_big_add_small); } +TEST(vec_znx_big, vec_znx_big_add_small) { test_fft64_vec_znx_big_add_small(vec_znx_big_add_small); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add_small2) { + test_fft64_vec_znx_big_add_small2(fft64_vec_znx_big_add_small2); +} +TEST(vec_znx_big, vec_znx_big_add_small2) { test_fft64_vec_znx_big_add_small2(vec_znx_big_add_small2); } + +void test_fft64_vec_znx_big_sub(VEC_ZNX_BIG_SUB_F vec_znx_big_sub_fcn) { + test_prelude(16, FFT64, 3, 5, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data, sa, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small_a(VEC_ZNX_BIG_SUB_SMALL_A_F vec_znx_big_sub_fcn) { + test_prelude(32, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small_b(VEC_ZNX_BIG_SUB_SMALL_B_F vec_znx_big_sub_fcn) { + test_prelude(16, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data, sa, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small2(VEC_ZNX_BIG_SUB_SMALL2_F vec_znx_big_sub_fcn) { + test_prelude(8, FFT64, 3, 6, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + 
ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub) { test_fft64_vec_znx_big_sub(fft64_vec_znx_big_sub); } +TEST(vec_znx_big, vec_znx_big_sub) { test_fft64_vec_znx_big_sub(vec_znx_big_sub); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small_a) { + test_fft64_vec_znx_big_sub_small_a(fft64_vec_znx_big_sub_small_a); +} +TEST(vec_znx_big, vec_znx_big_sub_small_a) { test_fft64_vec_znx_big_sub_small_a(vec_znx_big_sub_small_a); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small_b) { + test_fft64_vec_znx_big_sub_small_b(fft64_vec_znx_big_sub_small_b); +} +TEST(vec_znx_big, vec_znx_big_sub_small_b) { test_fft64_vec_znx_big_sub_small_b(vec_znx_big_sub_small_b); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small2) { + test_fft64_vec_znx_big_sub_small2(fft64_vec_znx_big_sub_small2); +} +TEST(vec_znx_big, vec_znx_big_sub_small2) { test_fft64_vec_znx_big_sub_small2(vec_znx_big_sub_small2); } + +static void test_vec_znx_big_normalize(VEC_ZNX_BIG_NORMALIZE_BASE2K_F normalize, + VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F normalize_tmp_bytes) { + // in the FFT64 case, big_normalize is just a forward. + // we will just test that the functions are callable + uint64_t n = 16; + uint64_t k = 12; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + uint64_t r_sl = n + 3; + def_rand_big(a, n, sa); + znx_vec_i64_layout r(n, sr, r_sl); + std::vector tmp_space(normalize_tmp_bytes(module)); + normalize(module, k, r.data(), sr, r_sl, a.data, sa, tmp_space.data()); + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_normalize_base2k) { + test_vec_znx_big_normalize(fft64_vec_znx_big_normalize_base2k, fft64_vec_znx_big_normalize_base2k_tmp_bytes); +} +TEST(vec_znx_big, vec_znx_big_normalize_base2k) { + test_vec_znx_big_normalize(vec_znx_big_normalize_base2k, vec_znx_big_normalize_base2k_tmp_bytes); +} + +static void test_vec_znx_big_range_normalize( // + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F normalize, + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F normalize_tmp_bytes) { + // in the FFT64 case, big_normalize is just a forward. 
+ // we will just test that the functions are callable + uint64_t n = 16; + uint64_t k = 11; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {6, 15, 21}) { + for (uint64_t sr : {3, 5, 7}) { + uint64_t r_sl = n + 3; + def_rand_big(a, n, sa); + uint64_t a_start = uniform_u64_bits(30) % (sa / 2); + uint64_t a_end = sa - (uniform_u64_bits(30) % (sa / 2)); + uint64_t a_step = (uniform_u64_bits(30) % 3) + 1; + uint64_t range_size = (a_end + a_step - 1 - a_start) / a_step; + fft64_vec_znx_big_layout aextr(n, range_size); + for (uint64_t i = 0, idx = a_start; idx < a_end; ++i, idx += a_step) { + aextr.set(i, a.get_copy(idx)); + } + znx_vec_i64_layout r(n, sr, r_sl); + znx_vec_i64_layout r2(n, sr, r_sl); + // tmp_space is large-enough for both + std::vector tmp_space(normalize_tmp_bytes(module)); + normalize(module, k, r.data(), sr, r_sl, a.data, a_start, a_end, a_step, tmp_space.data()); + fft64_vec_znx_big_normalize_base2k(module, k, r2.data(), sr, r_sl, aextr.data, range_size, tmp_space.data()); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), r2.get_copy(i)); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_range_normalize_base2k) { + test_vec_znx_big_range_normalize(fft64_vec_znx_big_range_normalize_base2k, + fft64_vec_znx_big_range_normalize_base2k_tmp_bytes); +} +TEST(vec_znx_big, vec_znx_big_range_normalize_base2k) { + test_vec_znx_big_range_normalize(vec_znx_big_range_normalize_base2k, vec_znx_big_range_normalize_base2k_tmp_bytes); +} + +static void test_vec_znx_big_rotate(VEC_ZNX_BIG_ROTATE_F rotate) { + // in the FFT64 case, big_normalize is just a forward. + // we will just test that the functions are callable + uint64_t n = 16; + int64_t p = 12; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + def_rand_big(a, n, sa); + fft64_vec_znx_big_layout r(n, sr); + rotate(module, p, r.data, sr, a.data, sa); + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 aa = a.get_copy_zext(i); + znx_i64 expect(n); + for (uint64_t j = 0; j < n; ++j) { + expect.set_coeff(j, aa.get_coeff(int64_t(j) - p)); + } + znx_i64 actual = r.get_copy(i); + ASSERT_EQ(expect, actual); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_rotate) { test_vec_znx_big_rotate(fft64_vec_znx_big_rotate); } +TEST(vec_znx_big, vec_znx_big_rotate) { test_vec_znx_big_rotate(vec_znx_big_rotate); } + +static void test_vec_znx_big_automorphism(VEC_ZNX_BIG_AUTOMORPHISM_F automorphism) { + // in the FFT64 case, big_normalize is just a forward. 
+ // we will just test that the functions are callable + uint64_t n = 16; + int64_t p = 11; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + def_rand_big(a, n, sa); + fft64_vec_znx_big_layout r(n, sr); + automorphism(module, p, r.data, sr, a.data, sa); + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 aa = a.get_copy_zext(i); + znx_i64 expect(n); + for (uint64_t j = 0; j < n; ++j) { + expect.set_coeff(p * j, aa.get_coeff(j)); + } + znx_i64 actual = r.get_copy(i); + ASSERT_EQ(expect, actual); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_automorphism) { test_vec_znx_big_automorphism(fft64_vec_znx_big_automorphism); } +TEST(vec_znx_big, vec_znx_big_automorphism) { test_vec_znx_big_automorphism(vec_znx_big_automorphism); } diff --git a/spqlios/lib/test/spqlios_vec_znx_dft_test.cpp b/spqlios/lib/test/spqlios_vec_znx_dft_test.cpp new file mode 100644 index 0000000..d72cb0b --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_znx_dft_test.cpp @@ -0,0 +1,193 @@ +#include + +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "spqlios/arithmetic/vec_znx_arithmetic.h" +#include "test/testlib/ntt120_dft.h" +#include "test/testlib/ntt120_layouts.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +static void test_fft64_vec_znx_dft(VEC_ZNX_DFT_F dft) { + for (uint64_t n : {2, 4, 128}) { + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + znx_vec_i64_layout a(n, sa, a_sl); + fft64_vec_znx_dft_layout res(n, sr); + a.fill_random(42); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_fft64(a.get_copy_zext(i)); + } + // test the function + thash hash_before = a.content_hash(); + dft(module, res.data, sr, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_before); + for (uint64_t i = 0; i < sr; ++i) { + reim_fft64vec actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), 1e-10); + } + } + } + delete_module_info(module); + } +} + +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +static void test_ntt120_vec_znx_dft(VEC_ZNX_DFT_F dft) { + for (uint64_t n : {2, 4, 128}) { + MODULE* module = new_module_info(n, NTT120); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + znx_vec_i64_layout a(n, sa, a_sl); + ntt120_vec_znx_dft_layout res(n, sr); + a.fill_random(42); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_ntt120(a.get_copy_zext(i)); + } + // test the function + thash hash_before = a.content_hash(); + dft(module, res.data, sr, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_before); + for (uint64_t i = 0; i < sr; ++i) { + q120_nttvec actual = res.get_copy_zext(i); + if (!(actual == expect[i])) { + for (uint64_t j = 0; j < n; ++j) { + std::cerr << actual.v[j] << " vs " << expect[i].v[j] << std::endl; + } + } + ASSERT_EQ(actual, expect[i]); + } + } + } + delete_module_info(module); + } +} +#endif + +TEST(vec_znx_dft, fft64_vec_znx_dft) { test_fft64_vec_znx_dft(fft64_vec_znx_dft); } +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +TEST(vec_znx_dft, ntt120_vec_znx_dft) { test_ntt120_vec_znx_dft(ntt120_vec_znx_dft_avx); } +#endif +TEST(vec_znx_dft, vec_znx_dft) { + test_fft64_vec_znx_dft(vec_znx_dft); +#ifdef __x86_64__ + // 
FIXME: currently, it only works on avx + test_ntt120_vec_znx_dft(ntt120_vec_znx_dft_avx); +#endif +} + +static void test_fft64_vec_znx_idft(VEC_ZNX_IDFT_F idft, VEC_ZNX_IDFT_TMP_A_F idft_tmp_a, + VEC_ZNX_IDFT_TMP_BYTES_F idft_tmp_bytes) { + for (uint64_t n : {2, 4, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + uint64_t tmp_size = idft_tmp_bytes ? idft_tmp_bytes(module) : 0; + std::vector tmp(tmp_size); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + fft64_vec_znx_dft_layout a(n, sa); + fft64_vec_znx_big_layout res(n, sr); + a.fill_dft_random_log2bound(22); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_rint_ifft64(a.get_copy_zext(i)); + } + // test the function + if (idft_tmp_bytes) { + thash hash_before = a.content_hash(); + idft(module, res.data, sr, a.data, sa, tmp.data()); + ASSERT_EQ(a.content_hash(), hash_before); + } else { + idft_tmp_a(module, res.data, sr, a.data, sa); + } + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 actual = res.get_copy_zext(i); + // ASSERT_EQ(res.get_copy_zext(i), expect[i]); + if (!(actual == expect[i])) { + for (uint64_t j = 0; j < n; ++j) { + std::cerr << actual.get_coeff(j) << " dft vs. " << expect[i].get_coeff(j) << std::endl; + } + FAIL(); + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx_dft, fft64_vec_znx_idft) { + test_fft64_vec_znx_idft(fft64_vec_znx_idft, nullptr, fft64_vec_znx_idft_tmp_bytes); +} +TEST(vec_znx_dft, fft64_vec_znx_idft_tmp_a) { test_fft64_vec_znx_idft(nullptr, fft64_vec_znx_idft_tmp_a, nullptr); } + +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +static void test_ntt120_vec_znx_idft(VEC_ZNX_IDFT_F idft, VEC_ZNX_IDFT_TMP_A_F idft_tmp_a, + VEC_ZNX_IDFT_TMP_BYTES_F idft_tmp_bytes) { + for (uint64_t n : {2, 4, 64, 128}) { + MODULE* module = new_module_info(n, NTT120); + uint64_t tmp_size = idft_tmp_bytes ? idft_tmp_bytes(module) : 0; + std::vector tmp(tmp_size); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + ntt120_vec_znx_dft_layout a(n, sa); + ntt120_vec_znx_big_layout res(n, sr); + a.fill_random(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_intt120(a.get_copy_zext(i)); + } + // test the function + if (idft_tmp_bytes) { + thash hash_before = a.content_hash(); + idft(module, res.data, sr, a.data, sa, tmp.data()); + ASSERT_EQ(a.content_hash(), hash_before); + } else { + idft_tmp_a(module, res.data, sr, a.data, sa); + } + for (uint64_t i = 0; i < sr; ++i) { + znx_i128 actual = res.get_copy_zext(i); + ASSERT_EQ(res.get_copy_zext(i), expect[i]); + // if (!(actual == expect[i])) { + // for (uint64_t j = 0; j < n; ++j) { + // std::cerr << actual.get_coeff(j) << " dft vs. 
" << expect[i].get_coeff(j) << std::endl; + // } + // FAIL(); + // } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx_dft, ntt120_vec_znx_idft) { + test_ntt120_vec_znx_idft(ntt120_vec_znx_idft_avx, nullptr, ntt120_vec_znx_idft_tmp_bytes_avx); +} +TEST(vec_znx_dft, ntt120_vec_znx_idft_tmp_a) { + test_ntt120_vec_znx_idft(nullptr, ntt120_vec_znx_idft_tmp_a_avx, nullptr); +} +#endif +TEST(vec_znx_dft, vec_znx_idft) { + test_fft64_vec_znx_idft(vec_znx_idft, nullptr, vec_znx_idft_tmp_bytes); +#ifdef __x86_64__ + // FIXME: currently, only supported on avx + test_ntt120_vec_znx_idft(vec_znx_idft, nullptr, vec_znx_idft_tmp_bytes); +#endif +} +TEST(vec_znx_dft, vec_znx_idft_tmp_a) { + test_fft64_vec_znx_idft(nullptr, vec_znx_idft_tmp_a, nullptr); +#ifdef __x86_64__ + // FIXME: currently, only supported on avx + test_ntt120_vec_znx_idft(nullptr, vec_znx_idft_tmp_a, nullptr); +#endif +} diff --git a/spqlios/lib/test/spqlios_vec_znx_test.cpp b/spqlios/lib/test/spqlios_vec_znx_test.cpp new file mode 100644 index 0000000..52540c1 --- /dev/null +++ b/spqlios/lib/test/spqlios_vec_znx_test.cpp @@ -0,0 +1,546 @@ +#include +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "spqlios/coeffs/coeffs_arithmetic.h" +#include "test/testlib/mod_q120.h" +#include "test/testlib/negacyclic_polynomial.h" +#include "testlib/fft64_dft.h" +#include "testlib/polynomial_vector.h" + +TEST(fft64_layouts, dft_idft_fft64) { + uint64_t n = 128; + // create a random polynomial + znx_i64 p(n); + for (uint64_t i = 0; i < n; ++i) { + p.set_coeff(i, uniform_i64_bits(36)); + } + // call fft + reim_fft64vec q = simple_fft64(p); + // call ifft and round + znx_i64 r = simple_rint_ifft64(q); + ASSERT_EQ(p, r); +} + +TEST(znx_layout, valid_test) { + uint64_t n = 4; + znx_vec_i64_layout v(n, 7, 13); + // this should be ok + v.set(0, znx_i64::zero(n)); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), znx_i64::zero(n)); + ASSERT_EQ(v.data()[2], 0); // should be ok + // this is also ok (zero extended vector) + ASSERT_EQ(v.get_copy_zext(1000), znx_i64::zero(n)); +} + +// disabling this test by default, since it depicts on purpose wrong accesses +#if 0 +TEST(znx_layout, valgrind_antipattern_test) { + uint64_t n = 4; + znx_vec_i64_layout v(n, 7, 13); + // this should be ok + v.set(0, znx_i64::zero(n)); + // this should abort (wrong ring dimension) + ASSERT_DEATH(v.set(3, znx_i64::zero(2 * n)), ""); + // this should abort (out of bounds) + ASSERT_DEATH(v.set(8, znx_i64::zero(n)), ""); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), znx_i64::zero(n)); + // should be an uninit read + ASSERT_TRUE(!(v.get_copy_zext(2) == znx_i64::zero(n))); // should be uninit + // should be an invalid read (inter-slice) + ASSERT_NE(v.data()[4], 0); + ASSERT_EQ(v.data()[2], 0); // should be ok + // should be an uninit read + ASSERT_NE(v.data()[13], 0); // should be uninit +} +#endif + +// test of binary operations + +// test for out of place calls +template +void test_vec_znx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + for (uint64_t sc : {7, 13, 15}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + uint64_t c_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + znx_vec_i64_layout lc(n, sc, c_sl); + std::vector expect(sc); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sc; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lc.data(), sc, c_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sc; ++i) { + ASSERT_EQ(lc.get_copy_zext(i), expect[i]); + } + } + } + } + delete_module_info(mod); + } +} +// test for inplace1 calls +template +void test_vec_znx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 64}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + } + delete_module_info(mod); + } +} +// test for inplace2 calls +template +void test_vec_znx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {4, 32, 64}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]); + } + } + } + delete_module_info(mod); + } +} +// test for inplace3 calls +template +void test_vec_znx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), la.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + la.data(), sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + delete_module_info(mod); + } +} +template +void test_vec_znx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_znx_elemw_binop_outplace(binop, ref_binop); + test_vec_znx_elemw_binop_inplace1(binop, ref_binop); + test_vec_znx_elemw_binop_inplace2(binop, ref_binop); + test_vec_znx_elemw_binop_inplace3(binop, ref_binop); +} + +static znx_i64 poly_add(const znx_i64& a, const znx_i64& b) { return a + b; } +TEST(vec_znx, vec_znx_add) { test_vec_znx_elemw_binop(vec_znx_add, poly_add); } +TEST(vec_znx, vec_znx_add_ref) { test_vec_znx_elemw_binop(vec_znx_add_ref, poly_add); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_add_avx) { test_vec_znx_elemw_binop(vec_znx_add_avx, poly_add); } +#endif + +static znx_i64 poly_sub(const znx_i64& a, const znx_i64& b) { return a - b; } +TEST(vec_znx, vec_znx_sub) { test_vec_znx_elemw_binop(vec_znx_sub, poly_sub); } +TEST(vec_znx, vec_znx_sub_ref) { test_vec_znx_elemw_binop(vec_znx_sub_ref, poly_sub); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_sub_avx) { test_vec_znx_elemw_binop(vec_znx_sub_avx, poly_sub); } +#endif + +// test of rotation operations + +// test for out of place calls +template +void test_vec_znx_elemw_unop_param_outplace(ACTUAL_FCN test_rotate, EXPECT_FCN ref_rotate, int64_t (*param_gen)()) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_rotate(p, la.get_copy_zext(i)); + } + test_rotate(mod, // + p, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_znx_elemw_unop_param_inplace(ACTUAL_FCN test_rotate, EXPECT_FCN ref_rotate, int64_t (*param_gen)()) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_rotate(p, la.get_copy_zext(i)); + } + test_rotate(mod, // N + p, //; + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +static int64_t random_rotate_param() { return uniform_i64(); } + +template +void test_vec_znx_elemw_rotate(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_znx_elemw_unop_param_outplace(binop, ref_binop, random_rotate_param); + test_vec_znx_elemw_unop_param_inplace(binop, ref_binop, random_rotate_param); +} + +static znx_i64 poly_rotate(const int64_t p, const znx_i64& a) { + uint64_t n = a.nn(); + znx_i64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p)); + } + return res; +} +TEST(vec_znx, vec_znx_rotate) { test_vec_znx_elemw_rotate(vec_znx_rotate, poly_rotate); } +TEST(vec_znx, vec_znx_rotate_ref) { test_vec_znx_elemw_rotate(vec_znx_rotate_ref, poly_rotate); } + +static int64_t random_automorphism_param() { return uniform_i64() | 1; } + +template +void test_vec_znx_elemw_automorphism(ACTUAL_FCN unop, EXPECT_FCN ref_unop) { + test_vec_znx_elemw_unop_param_outplace(unop, ref_unop, random_automorphism_param); + test_vec_znx_elemw_unop_param_inplace(unop, ref_unop, random_automorphism_param); +} + +static znx_i64 poly_automorphism(const int64_t p, const znx_i64& a) { + uint64_t n = a.nn(); + znx_i64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i * p, a.get_coeff(i)); + } + return res; +} + +TEST(vec_znx, vec_znx_automorphism) { test_vec_znx_elemw_automorphism(vec_znx_automorphism, poly_automorphism); } +TEST(vec_znx, vec_znx_automorphism_ref) { + test_vec_znx_elemw_automorphism(vec_znx_automorphism_ref, poly_automorphism); +} + +// test for out of place calls +template +void test_vec_znx_elemw_unop_outplace(ACTUAL_FCN test_unop, EXPECT_FCN ref_unop) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_unop(la.get_copy_zext(i)); + } + test_unop(mod, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_znx_elemw_unop_inplace(ACTUAL_FCN test_unop, EXPECT_FCN ref_unop) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_unop(la.get_copy_zext(i)); + } + test_unop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +template +void test_vec_znx_elemw_unop(ACTUAL_FCN unop, EXPECT_FCN ref_unop) { + test_vec_znx_elemw_unop_outplace(unop, ref_unop); + test_vec_znx_elemw_unop_inplace(unop, ref_unop); +} + +static znx_i64 poly_copy(const znx_i64& a) { return a; } + +TEST(vec_znx, vec_znx_copy) { test_vec_znx_elemw_unop(vec_znx_copy, poly_copy); } +TEST(vec_znx, vec_znx_copy_ref) { test_vec_znx_elemw_unop(vec_znx_copy_ref, poly_copy); } + +static znx_i64 poly_negate(const znx_i64& a) { return -a; } + +TEST(vec_znx, vec_znx_negate) { test_vec_znx_elemw_unop(vec_znx_negate, poly_negate); } +TEST(vec_znx, vec_znx_negate_ref) { test_vec_znx_elemw_unop(vec_znx_negate_ref, poly_negate); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_negate_avx) { test_vec_znx_elemw_unop(vec_znx_negate_avx, poly_negate); } +#endif + +static void test_vec_znx_zero(VEC_ZNX_ZERO_F zero) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + zero(mod, // N + la.data(), sa, a_sl // res + ); + znx_i64 ZERO = znx_i64::zero(n); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), ZERO) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_zero) { test_vec_znx_zero(vec_znx_zero); } +TEST(vec_znx, vec_znx_zero_ref) { test_vec_znx_zero(vec_znx_zero_ref); } + +static void vec_poly_normalize(const uint64_t base_k, std::vector& in) { + if (in.size() > 0) { + uint64_t n = in.front().nn(); + + znx_i64 out = znx_i64::random_log2bound(n, 62); + znx_i64 cinout(n); + for (int64_t i = in.size() - 1; i >= 0; --i) { + znx_normalize(n, base_k, in[i].data(), cinout.data(), in[i].data(), cinout.data()); + } + } +} + +template +void test_vec_znx_normalize_outplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tmp_bytes) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {1, 2, 6, 11}) { + for (uint64_t sb : {1, 2, 6, 11}) { + for (uint64_t base_k : {19}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + + std::vector la_norm; + for (uint64_t i = 0; i < sa; ++i) { + la_norm.push_back(la.get_copy_zext(i)); + } + vec_poly_normalize(base_k, la_norm); + + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout lb(n, sb, b_sl); + + const uint64_t tmp_size = tmp_bytes(mod); + uint8_t* tmp = new uint8_t[tmp_size]; + test_normalize(mod, // N + base_k, // base_k + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + tmp); + delete[] tmp; + + for (uint64_t i = 0; i < std::min(sa, sb); ++i) { + ASSERT_EQ(lb.get_copy_zext(i), la_norm[i]) << n << " " << sa << " " << sb << " " << i; + } + znx_i64 zero(n); + for (uint64_t i = std::min(sa, sb); i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), zero) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_normalize_outplace) { + test_vec_znx_normalize_outplace(vec_znx_normalize_base2k, vec_znx_normalize_base2k_tmp_bytes); +} +TEST(vec_znx, vec_znx_normalize_outplace_ref) { + test_vec_znx_normalize_outplace(vec_znx_normalize_base2k_ref, vec_znx_normalize_base2k_tmp_bytes_ref); +} + +template +void test_vec_znx_normalize_inplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tmp_bytes) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + for (uint64_t base_k : {19}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + + std::vector la_norm; + for (uint64_t i = 0; i < sa; ++i) { + la_norm.push_back(la.get_copy_zext(i)); + } + vec_poly_normalize(base_k, la_norm); + + const uint64_t tmp_size = tmp_bytes(mod); + uint8_t* tmp = new uint8_t[tmp_size]; + test_normalize(mod, // N + base_k, // base_k + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + tmp); + delete[] tmp; + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), la_norm[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_normalize_inplace) { + test_vec_znx_normalize_inplace(vec_znx_normalize_base2k, vec_znx_normalize_base2k_tmp_bytes); +} +TEST(vec_znx, vec_znx_normalize_inplace_ref) { + test_vec_znx_normalize_inplace(vec_znx_normalize_base2k_ref, vec_znx_normalize_base2k_tmp_bytes_ref); +} diff --git a/spqlios/lib/test/spqlios_vmp_product_test.cpp b/spqlios/lib/test/spqlios_vmp_product_test.cpp new file mode 100644 index 0000000..cb55818 --- /dev/null +++ b/spqlios/lib/test/spqlios_vmp_product_test.cpp @@ -0,0 +1,121 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +static void test_vmp_prepare_contiguous(VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, + VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, 
ncols); + mat.fill_random(30); + std::vector tmp_space(fft64_vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + MODULE* module = new_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(30); + std::vector tmp_space(tmp_bytes(module, nrows, ncols)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_prepare_contiguous) { + test_vmp_prepare_contiguous(vmp_prepare_contiguous, vmp_prepare_contiguous_tmp_bytes); +} +TEST(vec_znx, fft64_vmp_prepare_contiguous_ref) { + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_ref, fft64_vmp_prepare_contiguous_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_prepare_contiguous_avx) { + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_avx, fft64_vmp_prepare_contiguous_tmp_bytes); +} +#endif + +static void test_vmp_apply(VMP_APPLY_DFT_TO_DFT_F* apply, VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + fft64_vec_znx_dft_layout in(nn, in_size); + fft64_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + fft64_vec_znx_dft_layout out(nn, out_size); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in.get_copy_zext(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, out.data, out_size, in.data, in_size, pmat.data, mat_nrows, mat_ncols, tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_apply_to_dft) { 
test_vmp_apply(vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_ref) { + test_vmp_apply(fft64_vmp_apply_dft_to_dft_ref, fft64_vmp_apply_dft_to_dft_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_avx) { + test_vmp_apply(fft64_vmp_apply_dft_to_dft_avx, fft64_vmp_apply_dft_to_dft_tmp_bytes); +} +#endif diff --git a/spqlios/lib/test/spqlios_zn_approxdecomp_test.cpp b/spqlios/lib/test/spqlios_zn_approxdecomp_test.cpp new file mode 100644 index 0000000..d21f420 --- /dev/null +++ b/spqlios/lib/test/spqlios_zn_approxdecomp_test.cpp @@ -0,0 +1,46 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/test_commons.h" + +template +static void test_tndbl_approxdecomp( // + void (*approxdec)(const MOD_Z*, const TNDBL_APPROXDECOMP_GADGET*, INTTYPE*, uint64_t, const double*, uint64_t) // +) { + for (const uint64_t nn : {1, 3, 8, 51}) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell * nn}) { + std::vector in(nn); + std::vector out(res_size); + for (double& x : in) x = uniform_f64_bounds(-10, 10); + approxdec(module, gadget, out.data(), res_size, in.data(), nn); + // reconstruct the output + double err_bnd = pow(2., -double(ell * k) - 1); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in[j]; + double out_j = 0; + for (uint64_t i = 0; i < ell; ++i) { + out_j += out[ell * j + i] * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tndbl_approxdecomp_gadget(gadget); + } + } + delete_z_module_info(module); + } +} + +TEST(vec_rnx, i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i8_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i8_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i16_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i16_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i32_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i32_approxdecomp_from_tndbl_ref); } diff --git a/spqlios/lib/test/spqlios_zn_conversions_test.cpp b/spqlios/lib/test/spqlios_zn_conversions_test.cpp new file mode 100644 index 0000000..da2b94b --- /dev/null +++ b/spqlios/lib/test/spqlios_zn_conversions_test.cpp @@ -0,0 +1,104 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_Z*, DST_T* res, uint64_t res_size, const SRC_T* a, uint64_t a_size), + DST_T (*ideal_conv_f)(SRC_T x), SRC_T (*random_f)()) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t a_size : {0, 1, 2, 42}) { + for (uint64_t res_size : {0, 1, 2, 42}) { + for (uint64_t trials = 0; trials < 100; ++trials) { + std::vector a(a_size); + std::vector res(res_size); + uint64_t msize = std::min(a_size, res_size); + for (SRC_T& x : a) x = random_f(); + conv_f(module, res.data(), res_size, a.data(), a_size); + for (uint64_t i = 0; i < msize; ++i) { + DST_T expect = ideal_conv_f(a[i]); + DST_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + for (uint64_t i = msize; i < 
res_size; ++i) { + DST_T expect = 0; + SRC_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + } + } + } + delete_z_module_info(module); +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_dbl_to_tn32(DBL_TO_TN32_F dbl_to_tn32_f) { + test_conv(dbl_to_tn32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(zn_arithmetic, dbl_to_tn32) { test_dbl_to_tn32(dbl_to_tn32); } +TEST(zn_arithmetic, dbl_to_tn32_ref) { test_dbl_to_tn32(dbl_to_tn32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_tn32_to_dbl(TN32_TO_DBL_F tn32_to_dbl_f) { test_conv(tn32_to_dbl_f, ideal_tn32_to_dbl, random_t32); } + +TEST(zn_arithmetic, tn32_to_dbl) { test_tn32_to_dbl(tn32_to_dbl); } +TEST(zn_arithmetic, tn32_to_dbl_ref) { test_tn32_to_dbl(tn32_to_dbl_ref); } + +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); } + +static void test_dbl_round_to_i32(DBL_ROUND_TO_I32_F dbl_round_to_i32_f) { + test_conv(dbl_round_to_i32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, dbl_round_to_i32) { test_dbl_round_to_i32(dbl_round_to_i32); } +TEST(zn_arithmetic, dbl_round_to_i32_ref) { test_dbl_round_to_i32(dbl_round_to_i32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_i32_to_dbl(I32_TO_DBL_F i32_to_dbl_f) { + test_conv(i32_to_dbl_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, i32_to_dbl) { test_i32_to_dbl(i32_to_dbl); } +TEST(zn_arithmetic, i32_to_dbl_ref) { test_i32_to_dbl(i32_to_dbl_ref); } + +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) 
* pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } diff --git a/spqlios/lib/test/spqlios_zn_vmp_test.cpp b/spqlios/lib/test/spqlios_zn_vmp_test.cpp new file mode 100644 index 0000000..8f6fa25 --- /dev/null +++ b/spqlios/lib/test/spqlios_zn_vmp_test.cpp @@ -0,0 +1,67 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/zn_layouts.h" + +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + std::vector src(nrows * ncols); + zn32_pmat_layout out(nrows, ncols); + for (int32_t& x : src) x = uniform_i64_bits(32); + prep(module, out.data, src.data(), nrows, ncols); + for (uint64_t i = 0; i < nrows; ++i) { + for (uint64_t j = 0; j < ncols; ++j) { + int32_t in = src[i * ncols + j]; + int32_t actual = out.get(i, j); + ASSERT_EQ(actual, in); + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); } +TEST(zn, default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); } + +template +static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t, + const ZN32_VMP_PMAT*, uint64_t, uint64_t)) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + for (uint64_t a_size : {1, 2, 5, 15}) { + for (uint64_t res_size : {1, 2, 32, 42, 67}) { + std::vector a(a_size); + zn32_pmat_layout out(nrows, ncols); + std::vector res(res_size); + for (INTTYPE& x : a) x = uniform_i64_bits(32); + out.fill_random(); + std::vector expect = vmp_product(a.data(), a_size, res_size, out); + apply(module, res.data(), res_size, a.data(), a_size, out.data, nrows, ncols); + for (uint64_t i = 0; i < res_size; ++i) { + int32_t exp = expect[i]; + int32_t actual = res[i]; + ASSERT_EQ(actual, exp); + } + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_apply_i32) { test_zn_vmp_apply(zn32_vmp_apply_i32); } +TEST(zn, zn32_vmp_apply_i16) { test_zn_vmp_apply(zn32_vmp_apply_i16); } +TEST(zn, zn32_vmp_apply_i8) { test_zn_vmp_apply(zn32_vmp_apply_i8); } + +TEST(zn, default_zn32_vmp_apply_i32_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_ref); } +TEST(zn, default_zn32_vmp_apply_i16_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_ref); } +TEST(zn, default_zn32_vmp_apply_i8_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_ref); } + +#ifdef __x86_64__ +TEST(zn, default_zn32_vmp_apply_i32_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_avx); } +TEST(zn, default_zn32_vmp_apply_i16_avx) { 
test_zn_vmp_apply(default_zn32_vmp_apply_i16_avx); } +TEST(zn, default_zn32_vmp_apply_i8_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_avx); } +#endif diff --git a/spqlios/lib/test/spqlios_znx_small_test.cpp b/spqlios/lib/test/spqlios_znx_small_test.cpp new file mode 100644 index 0000000..477b4c5 --- /dev/null +++ b/spqlios/lib/test/spqlios_znx_small_test.cpp @@ -0,0 +1,26 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/negacyclic_polynomial.h" + +static void test_znx_small_single_product(ZNX_SMALL_SINGLE_PRODUCT_F product, + ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F product_tmp_bytes) { + for (const uint64_t nn : {2, 4, 8, 64}) { + MODULE* module = new_module_info(nn, FFT64); + znx_i64 a = znx_i64::random_log2bound(nn, 20); + znx_i64 b = znx_i64::random_log2bound(nn, 20); + znx_i64 expect = naive_product(a, b); + znx_i64 actual(nn); + std::vector tmp(znx_small_single_product_tmp_bytes(module)); + fft64_znx_small_single_product(module, actual.data(), a.data(), b.data(), tmp.data()); + ASSERT_EQ(actual, expect) << actual.get_coeff(0) << " vs. " << expect.get_coeff(0); + delete_module_info(module); + } +} + +TEST(znx_small, fft64_znx_small_single_product) { + test_znx_small_single_product(fft64_znx_small_single_product, fft64_znx_small_single_product_tmp_bytes); +} +TEST(znx_small, znx_small_single_product) { + test_znx_small_single_product(znx_small_single_product, znx_small_single_product_tmp_bytes); +} diff --git a/spqlios/lib/test/testlib/fft64_dft.cpp b/spqlios/lib/test/testlib/fft64_dft.cpp new file mode 100644 index 0000000..e66b680 --- /dev/null +++ b/spqlios/lib/test/testlib/fft64_dft.cpp @@ -0,0 +1,168 @@ +#include "fft64_dft.h" + +#include + +#include "../../spqlios/reim/reim_fft.h" +#include "../../spqlios/reim/reim_fft_internal.h" + +reim_fft64vec::reim_fft64vec(uint64_t n) : v(n, 0) {} +reim4_elem reim_fft64vec::get_blk(uint64_t blk) const { + return reim_view(v.size() / 2, (double*)v.data()).get_blk(blk); +} +double* reim_fft64vec::data() { return v.data(); } +const double* reim_fft64vec::data() const { return v.data(); } +uint64_t reim_fft64vec::nn() const { return v.size(); } +reim_fft64vec::reim_fft64vec(uint64_t n, const double* data) : v(data, data + n) {} +void reim_fft64vec::save_as(double* dest) const { memcpy(dest, v.data(), nn() * sizeof(double)); } +reim_fft64vec reim_fft64vec::zero(uint64_t n) { return reim_fft64vec(n); } +void reim_fft64vec::set_blk(uint64_t blk, const reim4_elem& value) { + reim_view(v.size() / 2, (double*)v.data()).set_blk(blk, value); +} +reim_fft64vec reim_fft64vec::dft_random(uint64_t n, uint64_t log2bound) { + return simple_fft64(znx_i64::random_log2bound(n, log2bound)); +} +reim_fft64vec reim_fft64vec::random(uint64_t n, double log2bound) { + double bound = pow(2., log2bound); + reim_fft64vec res(n); + for (uint64_t i = 0; i < n; ++i) { + res.v[i] = uniform_f64_bounds(-bound, bound); + } + return res; +} + +reim_fft64vec operator+(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + reim_fft64vec res(nn); + double* rv = res.data(); + const double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + rv[i] = av[i] + bv[i]; + } + return res; +} +reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + reim_fft64vec res(nn); + double* rv = res.data(); + const 
double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + rv[i] = av[i] - bv[i]; + } + return res; +} +reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + REQUIRE_DRAMATICALLY(nn >= 2, "test not defined for nn=1"); + uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* rv = res.data(); + const double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < m; ++i) { + rv[i] = av[i] * bv[i] - av[m + i] * bv[m + i]; + rv[m + i] = av[i] * bv[m + i] + av[m + i] * bv[i]; + } + return res; +} +reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + av[i] = av[i] + bv[i]; + } + return a; +} +reim_fft64vec& operator-=(reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + av[i] = av[i] - bv[i]; + } + return a; +} + +reim_fft64vec simple_fft64(const znx_i64& polynomial) { + const uint64_t nn = polynomial.nn(); + const uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* dat = res.data(); + for (uint64_t i = 0; i < nn; ++i) dat[i] = polynomial.get_coeff(i); + reim_fft_simple(m, dat); + return res; +} + +znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec) { + const uint64_t nn = fftvec.nn(); + const uint64_t m = nn / 2; + std::vector vv(fftvec.data(), fftvec.data() + nn); + double* v = vv.data(); + reim_ifft_simple(m, v); + znx_i64 res(nn); + for (uint64_t i = 0; i < nn; ++i) { + res.set_coeff(i, rint(v[i] / m)); + } + return res; +} + +rnx_f64 naive_ifft64(const reim_fft64vec& fftvec) { + const uint64_t nn = fftvec.nn(); + const uint64_t m = nn / 2; + std::vector vv(fftvec.data(), fftvec.data() + nn); + double* v = vv.data(); + reim_ifft_simple(m, v); + rnx_f64 res(nn); + for (uint64_t i = 0; i < nn; ++i) { + res.set_coeff(i, v[i] / m); + } + return res; +} +double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b) { + const uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "dimensions mismatch"); + const double* da = a.data(); + const double* db = b.data(); + double d = 0; + for (uint64_t i = 0; i < n; ++i) { + double di = abs(da[i] - db[i]); + if (di > d) d = di; + } + return d; +} + +reim_fft64vec simple_fft64(const rnx_f64& polynomial) { + const uint64_t nn = polynomial.nn(); + const uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* dat = res.data(); + for (uint64_t i = 0; i < nn; ++i) dat[i] = polynomial.get_coeff(i); + reim_fft_simple(m, dat); + return res; +} + +reim_fft64vec operator*(double coeff, const reim_fft64vec& v) { + const uint64_t nn = v.nn(); + reim_fft64vec res(nn); + double* rr = res.data(); + const double* vv = v.data(); + for (uint64_t i = 0; i < nn; ++i) rr[i] = coeff * vv[i]; + return res; +} + +rnx_f64 simple_ifft64(const reim_fft64vec& v) { + const uint64_t nn = v.nn(); + const uint64_t m = nn / 2; + rnx_f64 res(nn); + double* dat = res.data(); + memcpy(dat, v.data(), nn * sizeof(double)); + reim_ifft_simple(m, dat); + return res; +} diff --git a/spqlios/lib/test/testlib/fft64_dft.h b/spqlios/lib/test/testlib/fft64_dft.h new file mode 100644 index 0000000..32ee437 --- /dev/null +++ 
b/spqlios/lib/test/testlib/fft64_dft.h @@ -0,0 +1,43 @@ +#ifndef SPQLIOS_FFT64_DFT_H +#define SPQLIOS_FFT64_DFT_H + +#include "negacyclic_polynomial.h" +#include "reim4_elem.h" + +class reim_fft64vec { + std::vector v; + + public: + reim_fft64vec() = default; + explicit reim_fft64vec(uint64_t n); + reim_fft64vec(uint64_t n, const double* data); + uint64_t nn() const; + static reim_fft64vec zero(uint64_t n); + /** random complex coefficients (unstructured) */ + static reim_fft64vec random(uint64_t n, double log2bound); + /** random fft of a small int polynomial */ + static reim_fft64vec dft_random(uint64_t n, uint64_t log2bound); + double* data(); + const double* data() const; + void save_as(double* dest) const; + reim4_elem get_blk(uint64_t blk) const; + void set_blk(uint64_t blk, const reim4_elem& value); +}; + +reim_fft64vec operator+(const reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec operator*(double coeff, const reim_fft64vec& v); +reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec& operator-=(reim_fft64vec& a, const reim_fft64vec& b); + +/** infty distance */ +double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b); + +reim_fft64vec simple_fft64(const znx_i64& polynomial); +znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec); +rnx_f64 naive_ifft64(const reim_fft64vec& fftvec); +reim_fft64vec simple_fft64(const rnx_f64& polynomial); +rnx_f64 simple_ifft64(const reim_fft64vec& v); + +#endif // SPQLIOS_FFT64_DFT_H diff --git a/spqlios/lib/test/testlib/fft64_layouts.cpp b/spqlios/lib/test/testlib/fft64_layouts.cpp new file mode 100644 index 0000000..a8976b6 --- /dev/null +++ b/spqlios/lib/test/testlib/fft64_layouts.cpp @@ -0,0 +1,238 @@ +#include "fft64_layouts.h" +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +void* alloc64(uint64_t size) { + static uint64_t _msk64 = -64; + if (size == 0) return nullptr; + uint64_t rsize = (size + 63) & _msk64; + uint8_t* reps = (uint8_t*)spqlios_alloc(rsize); + REQUIRE_DRAMATICALLY(reps != 0, "Out of memory"); +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(reps + size, rsize - size); +#endif + return reps; +} + +fft64_vec_znx_dft_layout::fft64_vec_znx_dft_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_DFT*)alloc64(n * size * 8)), // + view(n / 2, size, (double*)data) {} + +fft64_vec_znx_dft_layout::~fft64_vec_znx_dft_layout() { spqlios_free(data); } + +double* fft64_vec_znx_dft_layout::get_addr(uint64_t idx) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow " << idx << " / " << size); + return ((double*)data) + idx * nn; +} +const double* fft64_vec_znx_dft_layout::get_addr(uint64_t idx) const { + REQUIRE_DRAMATICALLY(idx < size, "index overflow " << idx << " / " << size); + return ((double*)data) + idx * nn; +} +reim_fft64vec fft64_vec_znx_dft_layout::get_copy_zext(uint64_t idx) const { + if (idx < size) { + return reim_fft64vec(nn, get_addr(idx)); + } else { + return reim_fft64vec::zero(nn); + } +} +void fft64_vec_znx_dft_layout::fill_dft_random_log2bound(uint64_t bits) { + for (uint64_t i = 0; i < size; ++i) { + set(i, simple_fft64(znx_i64::random_log2bound(nn, bits))); + } +} +void fft64_vec_znx_dft_layout::set(uint64_t idx, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(value.nn() == nn, "ring dimension mismatch"); + value.save_as(get_addr(idx)); +} +thash 
fft64_vec_znx_dft_layout::content_hash() const { return test_hash(data, size * nn * sizeof(double)); } + +reim4_elem fft64_vec_znx_dft_layout::get(uint64_t idx, uint64_t blk) const { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + double* reim = ((double*)data) + idx * nn; + return reim4_elem(reim + blk * 4, reim + nn / 2 + blk * 4); +} +reim4_elem fft64_vec_znx_dft_layout::get_zext(uint64_t idx, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + if (idx < size) { + return get(idx, blk); + } else { + return reim4_elem::zero(); + } +} +void fft64_vec_znx_dft_layout::set(uint64_t idx, uint64_t blk, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + double* reim = ((double*)data) + idx * nn; + value.save_re_im(reim + blk * 4, reim + nn / 2 + blk * 4); +} +void fft64_vec_znx_dft_layout::fill_random(double log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, reim_fft64vec::random(nn, log2bound)); + } +} +void fft64_vec_znx_dft_layout::fill_dft_random(uint64_t log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, reim_fft64vec::dft_random(nn, log2bound)); + } +} + +fft64_vec_znx_big_layout::fft64_vec_znx_big_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_BIG*)alloc64(n * size * 8)) {} + +znx_i64 fft64_vec_znx_big_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return znx_i64(nn, ((int64_t*)data) + index * nn); +} +znx_i64 fft64_vec_znx_big_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i64(nn, ((int64_t*)data) + index * nn); + } else { + return znx_i64::zero(nn); + } +} +void fft64_vec_znx_big_layout::set(uint64_t index, const znx_i64& value) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + value.save_as(((int64_t*)data) + index * nn); +} +void fft64_vec_znx_big_layout::fill_random() { + for (uint64_t i = 0; i < size; ++i) { + set(i, znx_i64::random_log2bound(nn, 1)); + } +} +fft64_vec_znx_big_layout::~fft64_vec_znx_big_layout() { spqlios_free(data); } + +fft64_vmp_pmat_layout::fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols) + : nn(n), + nrows(nrows), + ncols(ncols), // + data((VMP_PMAT*)alloc64(nrows * ncols * nn * 8)) {} + +double* fft64_vmp_pmat_layout::get_addr(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + double* d = (double*)data; + if (col == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + return d + blk * nrows * ncols * 8 // major: blk + + col * nrows * 8 // col == ncols-1 + + row * 8; + } else { + // general case: columns go by pair + return d + blk * nrows * ncols * 8 // major: blk + + (col / 2) * (2 * nrows) * 8 // second: col pair index + + row * 2 * 8 // third: row index + + (col % 2) * 8; // minor: col in colpair + } +} + +reim4_elem fft64_vmp_pmat_layout::get(uint64_t row, uint64_t col, uint64_t blk) const { + return reim4_elem(get_addr(row, col, blk)); +} 
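// The strided addressing in fft64_vmp_pmat_layout::get_addr above is easier to
// follow in isolation: inside each reim4 block-major slice, columns are packed
// by pairs, and a trailing odd column is stored alone. The helper below is an
// illustrative standalone sketch of that same offset arithmetic (an assumed
// helper for explanation only, not an spqlios API); offsets are counted in
// doubles, with 8 doubles per reim4 block.
#include <cstdint>
#include <cstdio>

static uint64_t sketch_pmat_offset(uint64_t row, uint64_t col, uint64_t blk,
                                   uint64_t nrows, uint64_t ncols) {
  if (col == ncols - 1 && (ncols % 2 == 1)) {
    // special case: last column of an odd column count is stored alone
    return blk * nrows * ncols * 8 + col * nrows * 8 + row * 8;
  }
  // general case: blk-major, then column pair, then row, then column parity
  return blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8;
}

int main() {
  // dump the layout of a toy 2x3 pmat restricted to its first reim4 block
  for (uint64_t row = 0; row < 2; ++row) {
    for (uint64_t col = 0; col < 3; ++col) {
      std::printf("row=%u col=%u -> offset %u\n", (unsigned)row, (unsigned)col,
                  (unsigned)sketch_pmat_offset(row, col, 0, 2, 3));
    }
  }
  return 0;
}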
+reim4_elem fft64_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + if (row < nrows && col < ncols) { + return reim4_elem(get_addr(row, col, blk)); + } else { + return reim4_elem::zero(); + } +} +void fft64_vmp_pmat_layout::set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& value) const { + value.save_as(get_addr(row, col, blk)); +} + +fft64_vmp_pmat_layout::~fft64_vmp_pmat_layout() { spqlios_free(data); } + +reim_fft64vec fft64_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) { + return reim_fft64vec::zero(nn); + } + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + return reim_fft64vec(nn, addr); + } + // otherwise, reconstruct it block by block + reim_fft64vec res(nn); + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = get(row, col, blk); + res.set_blk(blk, v); + } + return res; +} +void fft64_vmp_pmat_layout::set(uint64_t row, uint64_t col, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "row overflow: " << col << " / " << ncols); + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + value.save_as(addr); + return; + } + // otherwise, reconstruct it block by block + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = value.get_blk(blk); + set(row, col, blk, v); + } +} +void fft64_vmp_pmat_layout::fill_random(double log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::random(nn, log2bound)); + } + } +} +void fft64_vmp_pmat_layout::fill_dft_random(uint64_t log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::dft_random(nn, log2bound)); + } + } +} + +fft64_svp_ppol_layout::fft64_svp_ppol_layout(uint64_t n) + : nn(n), // + data((SVP_PPOL*)alloc64(nn * 8)) {} + +reim_fft64vec fft64_svp_ppol_layout::get_copy() const { return reim_fft64vec(nn, (double*)data); } + +void fft64_svp_ppol_layout::set(const reim_fft64vec& value) { value.save_as((double*)data); } + +void fft64_svp_ppol_layout::fill_dft_random(uint64_t log2bound) { set(reim_fft64vec::dft_random(nn, log2bound)); } + +void fft64_svp_ppol_layout::fill_random(double log2bound) { set(reim_fft64vec::random(nn, log2bound)); } + +fft64_svp_ppol_layout::~fft64_svp_ppol_layout() { spqlios_free(data); } +thash fft64_svp_ppol_layout::content_hash() const { return test_hash(data, nn * sizeof(double)); } + +fft64_cnv_left_layout::fft64_cnv_left_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((CNV_PVEC_L*)alloc64(size * nn * 8)) {} + +reim4_elem fft64_cnv_left_layout::get(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + return reim4_elem(((double*)data) + blk * size + idx); +} + +fft64_cnv_left_layout::~fft64_cnv_left_layout() { spqlios_free(data); } + +fft64_cnv_right_layout::fft64_cnv_right_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((CNV_PVEC_R*)alloc64(size * nn * 8)) {} + +reim4_elem fft64_cnv_right_layout::get(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx 
overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + return reim4_elem(((double*)data) + blk * size + idx); +} + +fft64_cnv_right_layout::~fft64_cnv_right_layout() { spqlios_free(data); } diff --git a/spqlios/lib/test/testlib/fft64_layouts.h b/spqlios/lib/test/testlib/fft64_layouts.h new file mode 100644 index 0000000..ba71448 --- /dev/null +++ b/spqlios/lib/test/testlib/fft64_layouts.h @@ -0,0 +1,109 @@ +#ifndef SPQLIOS_FFT64_LAYOUTS_H +#define SPQLIOS_FFT64_LAYOUTS_H + +#include "../../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "fft64_dft.h" +#include "negacyclic_polynomial.h" +#include "reim4_elem.h" + +/** @brief test layout for the VEC_ZNX_DFT */ +struct fft64_vec_znx_dft_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_DFT* const data; + reim_vector_view view; + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + reim4_elem get(uint64_t idx, uint64_t blk) const; + reim4_elem get_zext(uint64_t idx, uint64_t blk) const; + void set(uint64_t idx, uint64_t blk, const reim4_elem& value); + fft64_vec_znx_dft_layout(uint64_t n, uint64_t size); + void fill_random_log2bound(uint64_t bits); + void fill_dft_random_log2bound(uint64_t bits); + double* get_addr(uint64_t idx); + const double* get_addr(uint64_t idx) const; + reim_fft64vec get_copy_zext(uint64_t idx) const; + void set(uint64_t idx, const reim_fft64vec& value); + thash content_hash() const; + ~fft64_vec_znx_dft_layout(); +}; + +/** @brief test layout for the VEC_ZNX_BIG */ +class fft64_vec_znx_big_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_BIG* const data; + fft64_vec_znx_big_layout(uint64_t n, uint64_t size); + void fill_random(); + znx_i64 get_copy(uint64_t index) const; + znx_i64 get_copy_zext(uint64_t index) const; + void set(uint64_t index, const znx_i64& value); + thash content_hash() const; + ~fft64_vec_znx_big_layout(); +}; + +/** @brief test layout for the VMP_PMAT */ +class fft64_vmp_pmat_layout { + public: + const uint64_t nn; + const uint64_t nrows; + const uint64_t ncols; + VMP_PMAT* const data; + fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); + double* get_addr(uint64_t row, uint64_t col, uint64_t blk) const; + reim4_elem get(uint64_t row, uint64_t col, uint64_t blk) const; + thash content_hash() const; + reim4_elem get_zext(uint64_t row, uint64_t col, uint64_t blk) const; + reim_fft64vec get_zext(uint64_t row, uint64_t col) const; + void set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& v) const; + void set(uint64_t row, uint64_t col, const reim_fft64vec& value); + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + ~fft64_vmp_pmat_layout(); +}; + +/** @brief test layout for the SVP_PPOL */ +class fft64_svp_ppol_layout { + public: + const uint64_t nn; + SVP_PPOL* const data; + fft64_svp_ppol_layout(uint64_t n); + thash content_hash() const; + reim_fft64vec get_copy() const; + void set(const reim_fft64vec&); + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + ~fft64_svp_ppol_layout(); +}; + +/** 
@brief test layout for the CNV_PVEC_L */ +class fft64_cnv_left_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_L* const data; + fft64_cnv_left_layout(uint64_t n, uint64_t size); + reim4_elem get(uint64_t idx, uint64_t blk); + thash content_hash() const; + ~fft64_cnv_left_layout(); +}; + +/** @brief test layout for the CNV_PVEC_R */ +class fft64_cnv_right_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_R* const data; + fft64_cnv_right_layout(uint64_t n, uint64_t size); + reim4_elem get(uint64_t idx, uint64_t blk); + thash content_hash() const; + ~fft64_cnv_right_layout(); +}; + +#endif // SPQLIOS_FFT64_LAYOUTS_H diff --git a/spqlios/lib/test/testlib/mod_q120.cpp b/spqlios/lib/test/testlib/mod_q120.cpp new file mode 100644 index 0000000..eb05de8 --- /dev/null +++ b/spqlios/lib/test/testlib/mod_q120.cpp @@ -0,0 +1,229 @@ +#include "mod_q120.h" + +#include +#include + +int64_t centermod(int64_t v, int64_t q) { + int64_t t = v % q; + if (t >= (q + 1) / 2) return t - q; + if (t < -q / 2) return t + q; + return t; +} + +int64_t centermod(uint64_t v, int64_t q) { + int64_t t = int64_t(v % uint64_t(q)); + if (t >= q / 2) return t - q; + return t; +} + +mod_q120::mod_q120() { + for (uint64_t i = 0; i < 4; ++i) { + a[i] = 0; + } +} + +mod_q120::mod_q120(int64_t a0, int64_t a1, int64_t a2, int64_t a3) { + a[0] = centermod(a0, Qi[0]); + a[1] = centermod(a1, Qi[1]); + a[2] = centermod(a2, Qi[2]); + a[3] = centermod(a3, Qi[3]); +} + +mod_q120 operator+(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] + y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 operator-(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] - y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 operator*(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] * y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120& operator+=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] + y.a[i], mod_q120::Qi[i]); + } + return x; +} + +mod_q120& operator-=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] - y.a[i], mod_q120::Qi[i]); + } + return x; +} + +mod_q120& operator*=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] * y.a[i], mod_q120::Qi[i]); + } + return x; +} + +int64_t modq_pow(int64_t x, int32_t k, int64_t q) { + k = (k % (q - 1) + q - 1) % (q - 1); + + int64_t res = 1; + int64_t x_pow = centermod(x, q); + while (k != 0) { + if (k & 1) res = centermod(res * x_pow, q); + x_pow = centermod(x_pow * x_pow, q); + k >>= 1; + } + return res; +} + +mod_q120 pow(const mod_q120& x, int32_t k) { + const int64_t r0 = modq_pow(x.a[0], k, x.Qi[0]); + const int64_t r1 = modq_pow(x.a[1], k, x.Qi[1]); + const int64_t r2 = modq_pow(x.a[2], k, x.Qi[2]); + const int64_t r3 = modq_pow(x.a[3], k, x.Qi[3]); + return mod_q120{r0, r1, r2, r3}; +} + +static int64_t half_modq(int64_t x, int64_t q) { + // q must be odd in this function + if (x % 2 == 0) return x / 2; + return centermod((x + q) / 2, q); +} + +mod_q120 half(const mod_q120& x) { + const int64_t r0 = half_modq(x.a[0], x.Qi[0]); + const int64_t r1 = half_modq(x.a[1], x.Qi[1]); + const int64_t r2 = half_modq(x.a[2], x.Qi[2]); + const int64_t r3 = half_modq(x.a[3], x.Qi[3]); + return mod_q120{r0, r1, r2, r3}; +} + 
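The helpers above define componentwise arithmetic on centered residues; their intended behaviour can be summed up by a few identities. A minimal self-check sketch (illustrative only: `mod_q120_sanity` is a name chosen here, everything else is declared in mod_q120.h, and the check on `half` relies on the four primes Qi being odd, as the comment in `half_modq` requires):

  #include <cassert>
  #include "mod_q120.h"

  // Illustrative self-check of the modular helpers defined in this file.
  static void mod_q120_sanity() {
    const mod_q120 x(123456789, -987654321, 42, -1);  // arbitrary residues
    const mod_q120 y(-5, 7, 11, 13);
    const mod_q120 one(1, 1, 1, 1);
    assert((x + y) - y == x);   // addition and subtraction are inverses
    assert(half(x + x) == x);   // half() undoes doubling (Qi odd)
    assert(pow(x, 0) == one);   // x^0 is the multiplicative identity
    assert(x * one == x);       // multiplying by 1 changes nothing
  }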
+bool operator==(const mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + if (x.a[i] != y.a[i]) return false; + } + return true; +} + +std::ostream& operator<<(std::ostream& out, const mod_q120& x) { + return out << "q120{" << x.a[0] << "," << x.a[1] << "," << x.a[2] << "," << x.a[3] << "}"; +} + +mod_q120 mod_q120::from_q120a(const void* addr) { + static const uint64_t _2p32 = UINT64_C(1) << 32; + const uint64_t* in = (const uint64_t*)addr; + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + REQUIRE_DRAMATICALLY(in[i] < _2p32, "invalid layout a q120"); + r.a[i] = centermod(in[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 mod_q120::from_q120b(const void* addr) { + const uint64_t* in = (const uint64_t*)addr; + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(in[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 mod_q120::from_q120c(const void* addr) { + //static const uint64_t _mask_2p32 = (uint64_t(1) << 32) - 1; + const uint32_t* in = (const uint32_t*)addr; + mod_q120 r; + for (uint64_t i = 0, k = 0; i < 8; i += 2, ++k) { + const uint64_t q = mod_q120::Qi[k]; + uint64_t u = in[i]; + uint64_t w = in[i + 1]; + REQUIRE_DRAMATICALLY(((u << 32) % q) == (w % q), + "invalid layout q120c: " << u << ".2^32 != " << (w >> 32) << " mod " << q); + r.a[k] = centermod(u, q); + } + return r; +} +__int128_t mod_q120::to_int128() const { + static const __int128_t qm[] = {(__int128_t(Qi[1]) * Qi[2]) * Qi[3], (__int128_t(Qi[0]) * Qi[2]) * Qi[3], + (__int128_t(Qi[0]) * Qi[1]) * Qi[3], (__int128_t(Qi[0]) * Qi[1]) * Qi[2]}; + static const int64_t CRTi[] = {Q1_CRT_CST, Q2_CRT_CST, Q3_CRT_CST, Q4_CRT_CST}; + static const __int128_t q = qm[0] * Qi[0]; + static const __int128_t qs2 = q / 2; + __int128_t res = 0; + for (uint64_t i = 0; i < 4; ++i) { + res += (a[i] * CRTi[i] % Qi[i]) * qm[i]; + } + res = (((res % q) + q + qs2) % q) - qs2; // centermod + return res; +} +void mod_q120::save_as_q120a(void* dest) const { + int64_t* d = (int64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[i] = a[i] + Qi[i]; + } +} +void mod_q120::save_as_q120b(void* dest) const { + int64_t* d = (int64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[i] = a[i] + (Qi[i] * (1 + uniform_u64_bits(32))); + } +} +void mod_q120::save_as_q120c(void* dest) const { + int32_t* d = (int32_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[2 * i] = a[i] + 3 * Qi[i]; + d[2 * i + 1] = (uint64_t(d[2 * i]) << 32) % uint64_t(Qi[i]); + } +} + +mod_q120 uniform_q120() { + test_rng& gen = randgen(); + std::uniform_int_distribution dista(0, mod_q120::Qi[0]); + std::uniform_int_distribution distb(0, mod_q120::Qi[1]); + std::uniform_int_distribution distc(0, mod_q120::Qi[2]); + std::uniform_int_distribution distd(0, mod_q120::Qi[3]); + return mod_q120(dista(gen), distb(gen), distc(gen), distd(gen)); +} + +void uniform_q120a(void* dest) { + uint64_t* res = (uint64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + res[i] = uniform_u64_bits(32); + } +} + +void uniform_q120b(void* dest) { + uint64_t* res = (uint64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + res[i] = uniform_u64(); + } +} + +void uniform_q120c(void* dest) { + uint32_t* res = (uint32_t*)dest; + static const uint64_t _2p32 = uint64_t(1) << 32; + for (uint64_t i = 0, k = 0; i < 8; i += 2, ++k) { + const uint64_t q = mod_q120::Qi[k]; + const uint64_t z = uniform_u64_bits(32); + const uint64_t z_pow_red = (z << 32) % q; + const uint64_t room = (_2p32 - z_pow_red) / q; + const uint64_t z_pow = z_pow_red + (uniform_u64() % room) * q; + 
REQUIRE_DRAMATICALLY(z < _2p32, "bug!"); + REQUIRE_DRAMATICALLY(z_pow < _2p32, "bug!"); + REQUIRE_DRAMATICALLY(z_pow % q == (z << 32) % q, "bug!"); + + res[i] = (uint32_t)z; + res[i + 1] = (uint32_t)z_pow; + } +} diff --git a/spqlios/lib/test/testlib/mod_q120.h b/spqlios/lib/test/testlib/mod_q120.h new file mode 100644 index 0000000..45c7cc7 --- /dev/null +++ b/spqlios/lib/test/testlib/mod_q120.h @@ -0,0 +1,49 @@ +#ifndef SPQLIOS_MOD_Q120_H +#define SPQLIOS_MOD_Q120_H + +#include + +#include "../../spqlios/q120/q120_common.h" +#include "test_commons.h" + +/** @brief centered modulo q */ +int64_t centermod(int64_t v, int64_t q); +int64_t centermod(uint64_t v, int64_t q); + +/** @brief this class represents an integer mod Q120 */ +class mod_q120 { + public: + static constexpr int64_t Qi[] = {Q1, Q2, Q3, Q4}; + int64_t a[4]; + mod_q120(int64_t a1, int64_t a2, int64_t a3, int64_t a4); + mod_q120(); + __int128_t to_int128() const; + static mod_q120 from_q120a(const void* addr); + static mod_q120 from_q120b(const void* addr); + static mod_q120 from_q120c(const void* addr); + void save_as_q120a(void* dest) const; + void save_as_q120b(void* dest) const; + void save_as_q120c(void* dest) const; +}; + +mod_q120 operator+(const mod_q120& x, const mod_q120& y); +mod_q120 operator-(const mod_q120& x, const mod_q120& y); +mod_q120 operator*(const mod_q120& x, const mod_q120& y); +mod_q120& operator+=(mod_q120& x, const mod_q120& y); +mod_q120& operator-=(mod_q120& x, const mod_q120& y); +mod_q120& operator*=(mod_q120& x, const mod_q120& y); +std::ostream& operator<<(std::ostream& out, const mod_q120& x); +bool operator==(const mod_q120& x, const mod_q120& y); +mod_q120 pow(const mod_q120& x, int32_t k); +mod_q120 half(const mod_q120& x); + +/** @brief a uniformly drawn number mod Q120 */ +mod_q120 uniform_q120(); +/** @brief a uniformly random mod Q120 layout A (4 integers < 2^32) */ +void uniform_q120a(void* dest); +/** @brief a uniformly random mod Q120 layout B (4 integers < 2^64) */ +void uniform_q120b(void* dest); +/** @brief a uniformly random mod Q120 layout C (4 integers repr. 
x,2^32x) */ +void uniform_q120c(void* dest); + +#endif // SPQLIOS_MOD_Q120_H diff --git a/spqlios/lib/test/testlib/negacyclic_polynomial.cpp b/spqlios/lib/test/testlib/negacyclic_polynomial.cpp new file mode 100644 index 0000000..ee516c6 --- /dev/null +++ b/spqlios/lib/test/testlib/negacyclic_polynomial.cpp @@ -0,0 +1,18 @@ +#include "negacyclic_polynomial_impl.h" + +// explicit instantiation +EXPLICIT_INSTANTIATE_POLYNOMIAL(__int128_t); +EXPLICIT_INSTANTIATE_POLYNOMIAL(int64_t); +EXPLICIT_INSTANTIATE_POLYNOMIAL(double); + +double infty_dist(const rnx_f64& a, const rnx_f64& b) { + const uint64_t nn = a.nn(); + const double* aa = a.data(); + const double* bb = b.data(); + double res = 0.; + for (uint64_t i = 0; i < nn; ++i) { + double d = fabs(aa[i] - bb[i]); + if (d > res) res = d; + } + return res; +} diff --git a/spqlios/lib/test/testlib/negacyclic_polynomial.h b/spqlios/lib/test/testlib/negacyclic_polynomial.h new file mode 100644 index 0000000..8f5b17f --- /dev/null +++ b/spqlios/lib/test/testlib/negacyclic_polynomial.h @@ -0,0 +1,69 @@ +#ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_H +#define SPQLIOS_NEGACYCLIC_POLYNOMIAL_H + +#include + +#include "test_commons.h" + +template +class polynomial; +typedef polynomial<__int128_t> znx_i128; +typedef polynomial znx_i64; +typedef polynomial rnx_f64; + +template +class polynomial { + public: + std::vector coeffs; + /** @brief create a polynomial out of existing coeffs */ + polynomial(uint64_t N, const T* c); + /** @brief zero polynomial of dimension N */ + explicit polynomial(uint64_t N); + /** @brief empty polynomial (dim 0) */ + polynomial(); + + /** @brief ring dimension */ + uint64_t nn() const; + /** @brief special setter (accept any indexes, and does the negacyclic translation) */ + void set_coeff(int64_t i, T v); + /** @brief special getter (accept any indexes, and does the negacyclic translation) */ + T get_coeff(int64_t i) const; + /** @brief returns the coefficient layout */ + T* data(); + /** @brief returns the coefficient layout (const version) */ + const T* data() const; + /** @brief saves to the layout */ + void save_as(T* dest) const; + /** @brief zero */ + static polynomial zero(uint64_t n); + /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ + static polynomial random_log2bound(uint64_t n, uint64_t log2bound); + /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ + static polynomial random(uint64_t n); + /** @brief random polynomial with coefficient in [lb;ub] */ + static polynomial random_bound(uint64_t n, const T lb, const T ub); +}; + +/** @brief equality operator (used during tests) */ +template +bool operator==(const polynomial& a, const polynomial& b); + +/** @brief addition operator (used during tests) */ +template +polynomial operator+(const polynomial& a, const polynomial& b); + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a, const polynomial& b); + +/** @brief negation operator (used during tests) */ +template +polynomial operator-(const polynomial& a); + +template +polynomial naive_product(const polynomial& a, const polynomial& b); + +/** @brief distance between two real polynomials (used during tests) */ +double infty_dist(const rnx_f64& a, const rnx_f64& b); + +#endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_H diff --git a/spqlios/lib/test/testlib/negacyclic_polynomial_impl.h b/spqlios/lib/test/testlib/negacyclic_polynomial_impl.h new file mode 100644 index 0000000..e31c5a5 --- /dev/null +++ 
b/spqlios/lib/test/testlib/negacyclic_polynomial_impl.h @@ -0,0 +1,247 @@ +#ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H +#define SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H + +#include "negacyclic_polynomial.h" + +template +polynomial::polynomial(uint64_t N, const T* c) : coeffs(N) { + for (uint64_t i = 0; i < N; ++i) coeffs[i] = c[i]; +} +/** @brief zero polynomial of dimension N */ +template +polynomial::polynomial(uint64_t N) : coeffs(N, 0) {} +/** @brief empty polynomial (dim 0) */ +template +polynomial::polynomial() {} + +/** @brief ring dimension */ +template +uint64_t polynomial::nn() const { + uint64_t n = coeffs.size(); + REQUIRE_DRAMATICALLY(is_pow2(n), "polynomial dim is not a pow of 2"); + return n; +} + +/** @brief special setter (accept any indexes, and does the negacyclic translation) */ +template +void polynomial::set_coeff(int64_t i, T v) { + const uint64_t n = nn(); + const uint64_t _2nm = 2 * n - 1; + uint64_t pos = uint64_t(i) & _2nm; + if (pos < n) { + coeffs[pos] = v; + } else { + coeffs[pos - n] = -v; + } +} +/** @brief special getter (accept any indexes, and does the negacyclic translation) */ +template +T polynomial::get_coeff(int64_t i) const { + const uint64_t n = nn(); + const uint64_t _2nm = 2 * n - 1; + uint64_t pos = uint64_t(i) & _2nm; + if (pos < n) { + return coeffs[pos]; + } else { + return -coeffs[pos - n]; + } +} +/** @brief returns the coefficient layout */ +template +T* polynomial::data() { + return coeffs.data(); +} + +template +void polynomial::save_as(T* dest) const { + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + dest[i] = coeffs[i]; + } +} + +/** @brief returns the coefficient layout (const version) */ +template +const T* polynomial::data() const { + return coeffs.data(); +} + +/** @brief returns the coefficient layout (const version) */ +template +polynomial polynomial::zero(uint64_t n) { + return polynomial(n); +} + +/** @brief equality operator (used during tests) */ +template +bool operator==(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + for (uint64_t i = 0; i < n; ++i) { + if (a.get_coeff(i) != b.get_coeff(i)) return false; + } + return true; +} + +/** @brief addition operator (used during tests) */ +template +polynomial operator+(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i) + b.get_coeff(i)); + } + return res; +} + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i) - b.get_coeff(i)); + } + return res; +} + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a) { + uint64_t n = a.nn(); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, -a.get_coeff(i)); + } + return res; +} + +/** @brief random polynomial */ +template +polynomial random_polynomial(uint64_t n); + +/** @brief random int64 polynomial */ +template <> +polynomial random_polynomial(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64()); + } + return res; +} + +/** @brief random float64 gaussian polynomial */ +template <> +polynomial 
random_polynomial(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, random_f64_gaussian()); + } + return res; +} + +template +polynomial random_polynomial_bounds(uint64_t n, const T lb, const T ub); + +/** @brief random int64 polynomial */ +template <> +polynomial random_polynomial_bounds(uint64_t n, const int64_t lb, const int64_t ub) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64_bounds(lb, ub)); + } + return res; +} + +/** @brief random float64 gaussian polynomial */ +template <> +polynomial random_polynomial_bounds(uint64_t n, const double lb, const double ub) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(lb, ub)); + } + return res; +} + +/** @brief random int64 polynomial */ +template <> +polynomial<__int128_t> random_polynomial_bounds(uint64_t n, const __int128_t lb, const __int128_t ub) { + polynomial<__int128_t> res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i128_bounds(lb, ub)); + } + return res; +} + +template +polynomial random_polynomial_bits(uint64_t n, const uint64_t bits) { + T b = UINT64_C(1) << bits; + return random_polynomial_bounds(n, -b, b); +} + +template <> +polynomial polynomial::random_log2bound(uint64_t n, uint64_t log2bound) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64_bits(log2bound)); + } + return res; +} + +template <> +polynomial polynomial::random(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_u64()); + } + return res; +} + +template <> +polynomial polynomial::random_log2bound(uint64_t n, uint64_t log2bound) { + polynomial res(n); + double bound = pow(2., log2bound); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(-bound, bound)); + } + return res; +} + +template <> +polynomial polynomial::random(uint64_t n) { + polynomial res(n); + double bound = 2.; + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(-bound, bound)); + } + return res; +} + +template +polynomial naive_product(const polynomial& a, const polynomial& b) { + const int64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == uint64_t(nn), "dimension mismatch!"); + polynomial res(nn); + for (int64_t i = 0; i < nn; ++i) { + T ri = 0; + for (int64_t j = 0; j < nn; ++j) { + ri += a.get_coeff(j) * b.get_coeff(i - j); + } + res.set_coeff(i, ri); + } + return res; +} + +#define EXPLICIT_INSTANTIATE_POLYNOMIAL(TYPE) \ + template class polynomial; \ + template bool operator==(const polynomial& a, const polynomial& b); \ + template polynomial operator+(const polynomial& a, const polynomial& b); \ + template polynomial operator-(const polynomial& a, const polynomial& b); \ + template polynomial operator-(const polynomial& a); \ + template polynomial random_polynomial_bits(uint64_t n, const uint64_t bits); \ + template polynomial naive_product(const polynomial& a, const polynomial& b); \ + // template polynomial random_polynomial(uint64_t n); + +#endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H diff --git a/spqlios/lib/test/testlib/ntt120_dft.cpp b/spqlios/lib/test/testlib/ntt120_dft.cpp new file mode 100644 index 0000000..5d4b6f5 --- /dev/null +++ b/spqlios/lib/test/testlib/ntt120_dft.cpp @@ -0,0 +1,122 @@ +#include "ntt120_dft.h" + +#include "mod_q120.h" + +// @brief alternative version of the NTT + +/** for all s=k/2^17, root_of_unity(s) = omega_0^k */ +static mod_q120 root_of_unity(double s) { + static 
mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + static double _2pow17 = 1 << 17; + return pow(omega_2pow17, s * _2pow17); +} +static mod_q120 root_of_unity_inv(double s) { + static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + static double _2pow17 = 1 << 17; + return pow(omega_2pow17, -s * _2pow17); +} + +/** recursive naive ntt */ +static void q120_ntt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { + if (n == 1) return; + const uint64_t h = n / 2; + const double s = entry_pwr / 2.; + mod_q120 om = root_of_unity(s); + for (uint64_t j = 0; j < h; ++j) { + mod_q120 om_right = data[h + j] * om; + data[h + j] = data[j] - om_right; + data[j] = data[j] + om_right; + } + q120_ntt_naive_rec(h, s, data); + q120_ntt_naive_rec(h, s + 0.5, data + h); +} +static void q120_intt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { + if (n == 1) return; + const uint64_t h = n / 2; + const double s = entry_pwr / 2.; + q120_intt_naive_rec(h, s, data); + q120_intt_naive_rec(h, s + 0.5, data + h); + mod_q120 om = root_of_unity_inv(s); + for (uint64_t j = 0; j < h; ++j) { + mod_q120 dat_diff = half(data[j] - data[h + j]); + data[j] = half(data[j] + data[h + j]); + data[h + j] = dat_diff * om; + } +} + +/** user friendly version */ +q120_nttvec simple_ntt120(const znx_i64& polynomial) { + const uint64_t n = polynomial.nn(); + q120_nttvec res(n); + for (uint64_t i = 0; i < n; ++i) { + int64_t xi = polynomial.get_coeff(i); + res.v[i] = mod_q120(xi, xi, xi, xi); + } + q120_ntt_naive_rec(n, 0.5, res.v.data()); + return res; +} + +znx_i128 simple_intt120(const q120_nttvec& fftvec) { + const uint64_t n = fftvec.nn(); + q120_nttvec copy = fftvec; + znx_i128 res(n); + q120_intt_naive_rec(n, 0.5, copy.v.data()); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, copy.v[i].to_int128()); + } + return res; +} +bool operator==(const q120_nttvec& a, const q120_nttvec& b) { return a.v == b.v; } + +std::vector q120_ntt_naive(const std::vector& x) { + std::vector res = x; + q120_ntt_naive_rec(res.size(), 0.5, res.data()); + return res; +} +q120_nttvec::q120_nttvec(uint64_t n) : v(n) {} +q120_nttvec::q120_nttvec(uint64_t n, const q120b* data) : v(n) { + int64_t* d = (int64_t*)data; + for (uint64_t i = 0; i < n; ++i) { + v[i] = mod_q120::from_q120b(d + 4 * i); + } +} +q120_nttvec::q120_nttvec(uint64_t n, const q120c* data) : v(n) { + int64_t* d = (int64_t*)data; + for (uint64_t i = 0; i < n; ++i) { + v[i] = mod_q120::from_q120c(d + 4 * i); + } +} +uint64_t q120_nttvec::nn() const { return v.size(); } +q120_nttvec q120_nttvec::zero(uint64_t n) { return q120_nttvec(n); } +void q120_nttvec::save_as(q120a* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120a(d + 4 * i); + } +} +void q120_nttvec::save_as(q120b* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120b(d + 4 * i); + } +} +void q120_nttvec::save_as(q120c* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120c(d + 4 * i); + } +} +mod_q120 q120_nttvec::get_blk(uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn(), "blk overflow"); + return v[blk]; +} +q120_nttvec q120_nttvec::random(uint64_t n) { + q120_nttvec res(n); + for (uint64_t i = 0; i < n; ++i) { + res.v[i] = uniform_q120(); + } + return res; +} diff --git a/spqlios/lib/test/testlib/ntt120_dft.h 
b/spqlios/lib/test/testlib/ntt120_dft.h new file mode 100644 index 0000000..80f5679 --- /dev/null +++ b/spqlios/lib/test/testlib/ntt120_dft.h @@ -0,0 +1,31 @@ +#ifndef SPQLIOS_NTT120_DFT_H +#define SPQLIOS_NTT120_DFT_H + +#include + +#include "../../spqlios/q120/q120_arithmetic.h" +#include "mod_q120.h" +#include "negacyclic_polynomial.h" +#include "test_commons.h" + +class q120_nttvec { + public: + std::vector v; + q120_nttvec() = default; + explicit q120_nttvec(uint64_t n); + q120_nttvec(uint64_t n, const q120b* data); + q120_nttvec(uint64_t n, const q120c* data); + uint64_t nn() const; + static q120_nttvec zero(uint64_t n); + static q120_nttvec random(uint64_t n); + void save_as(q120a* dest) const; + void save_as(q120b* dest) const; + void save_as(q120c* dest) const; + mod_q120 get_blk(uint64_t blk) const; +}; + +q120_nttvec simple_ntt120(const znx_i64& polynomial); +znx_i128 simple_intt120(const q120_nttvec& fftvec); +bool operator==(const q120_nttvec& a, const q120_nttvec& b); + +#endif // SPQLIOS_NTT120_DFT_H diff --git a/spqlios/lib/test/testlib/ntt120_layouts.cpp b/spqlios/lib/test/testlib/ntt120_layouts.cpp new file mode 100644 index 0000000..d1e582f --- /dev/null +++ b/spqlios/lib/test/testlib/ntt120_layouts.cpp @@ -0,0 +1,66 @@ +#include "ntt120_layouts.h" + +mod_q120x2::mod_q120x2() {} +mod_q120x2::mod_q120x2(const mod_q120& a, const mod_q120& b) { + value[0] = a; + value[1] = b; +} +mod_q120x2::mod_q120x2(q120x2b* addr) { + uint64_t* p = (uint64_t*)addr; + value[0] = mod_q120::from_q120b(p); + value[1] = mod_q120::from_q120b(p + 4); +} + +ntt120_vec_znx_dft_layout::ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_DFT*)alloc64(n * size * 4 * sizeof(uint64_t))) {} + +mod_q120x2 ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx, uint64_t blk) { + return mod_q120x2(get_blk(idx, blk)); +} +q120x2b* ntt120_vec_znx_dft_layout::get_blk(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx overflow"); + REQUIRE_DRAMATICALLY(blk < nn / 2, "blk overflow"); + uint64_t* d = (uint64_t*)data; + return (q120x2b*)(d + 4 * nn * idx + 8 * blk); +} +ntt120_vec_znx_dft_layout::~ntt120_vec_znx_dft_layout() { spqlios_free(data); } +q120_nttvec ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx) { + int64_t* d = (int64_t*)data; + if (idx < size) { + return q120_nttvec(nn, (q120b*)(d + idx * nn * 4)); + } else { + return q120_nttvec::zero(nn); + } +} +void ntt120_vec_znx_dft_layout::set(uint64_t idx, const q120_nttvec& value) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + q120b* dest_addr = (q120b*)((int64_t*)data + idx * nn * 4); + value.save_as(dest_addr); +} +void ntt120_vec_znx_dft_layout::fill_random() { + for (uint64_t i = 0; i < size; ++i) { + set(i, q120_nttvec::random(nn)); + } +} +thash ntt120_vec_znx_dft_layout::content_hash() const { return test_hash(data, nn * size * 4 * sizeof(int64_t)); } +ntt120_vec_znx_big_layout::ntt120_vec_znx_big_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((VEC_ZNX_BIG*)alloc64(n * size * sizeof(__int128_t))) {} + +znx_i128 ntt120_vec_znx_big_layout::get_copy(uint64_t index) const { return znx_i128(nn, get_addr(index)); } +znx_i128 ntt120_vec_znx_big_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i128(nn, get_addr(index)); + } else { + return znx_i128::zero(nn); + } +} +__int128* ntt120_vec_znx_big_layout::get_addr(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << 
index << " / " << size); + return (__int128_t*)data + index * nn; +} +void ntt120_vec_znx_big_layout::set(uint64_t index, const znx_i128& value) { value.save_as(get_addr(index)); } +ntt120_vec_znx_big_layout::~ntt120_vec_znx_big_layout() { spqlios_free(data); } diff --git a/spqlios/lib/test/testlib/ntt120_layouts.h b/spqlios/lib/test/testlib/ntt120_layouts.h new file mode 100644 index 0000000..d8fcc08 --- /dev/null +++ b/spqlios/lib/test/testlib/ntt120_layouts.h @@ -0,0 +1,103 @@ +#ifndef SPQLIOS_NTT120_LAYOUTS_H +#define SPQLIOS_NTT120_LAYOUTS_H + +#include "../../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "mod_q120.h" +#include "negacyclic_polynomial.h" +#include "ntt120_dft.h" +#include "test_commons.h" + +struct q120b_vector_view {}; + +struct mod_q120x2 { + mod_q120 value[2]; + mod_q120x2(); + mod_q120x2(const mod_q120& a, const mod_q120& b); + mod_q120x2(__int128_t value); + explicit mod_q120x2(q120x2b* addr); + explicit mod_q120x2(q120x2c* addr); + void save_as(q120x2b* addr) const; + void save_as(q120x2c* addr) const; + static mod_q120x2 random(); +}; +mod_q120x2 operator+(const mod_q120x2& a, const mod_q120x2& b); +mod_q120x2 operator-(const mod_q120x2& a, const mod_q120x2& b); +mod_q120x2 operator*(const mod_q120x2& a, const mod_q120x2& b); +bool operator==(const mod_q120x2& a, const mod_q120x2& b); +bool operator!=(const mod_q120x2& a, const mod_q120x2& b); +mod_q120x2& operator+=(mod_q120x2& a, const mod_q120x2& b); +mod_q120x2& operator-=(mod_q120x2& a, const mod_q120x2& b); + +/** @brief test layout for the VEC_ZNX_DFT */ +struct ntt120_vec_znx_dft_layout { + const uint64_t nn; + const uint64_t size; + VEC_ZNX_DFT* const data; + ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size); + mod_q120x2 get_copy_zext(uint64_t idx, uint64_t blk); + q120_nttvec get_copy_zext(uint64_t idx); + void set(uint64_t idx, const q120_nttvec& v); + q120x2b* get_blk(uint64_t idx, uint64_t blk); + thash content_hash() const; + void fill_random(); + ~ntt120_vec_znx_dft_layout(); +}; + +/** @brief test layout for the VEC_ZNX_BIG */ +class ntt120_vec_znx_big_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_BIG* const data; + ntt120_vec_znx_big_layout(uint64_t n, uint64_t size); + + private: + __int128* get_addr(uint64_t index) const; + + public: + znx_i128 get_copy(uint64_t index) const; + znx_i128 get_copy_zext(uint64_t index) const; + void set(uint64_t index, const znx_i128& value); + ~ntt120_vec_znx_big_layout(); +}; + +/** @brief test layout for the VMP_PMAT */ +class ntt120_vmp_pmat_layout { + const uint64_t nn; + const uint64_t nrows; + const uint64_t ncols; + VMP_PMAT* const data; + ntt120_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); + mod_q120x2 get(uint64_t row, uint64_t col, uint64_t blk) const; + ~ntt120_vmp_pmat_layout(); +}; + +/** @brief test layout for the SVP_PPOL */ +class ntt120_svp_ppol_layout { + const uint64_t nn; + SVP_PPOL* const data; + ntt120_svp_ppol_layout(uint64_t n); + ~ntt120_svp_ppol_layout(); +}; + +/** @brief test layout for the CNV_PVEC_L */ +class ntt120_cnv_left_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_L* const data; + ntt120_cnv_left_layout(uint64_t n, uint64_t size); + mod_q120x2 get(uint64_t idx, uint64_t blk); + ~ntt120_cnv_left_layout(); +}; + +/** @brief test layout for the CNV_PVEC_R */ +class ntt120_cnv_right_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_R* const data; + ntt120_cnv_right_layout(uint64_t n, uint64_t size); + mod_q120x2 get(uint64_t idx, uint64_t blk); + 
~ntt120_cnv_right_layout(); +}; + +#endif // SPQLIOS_NTT120_LAYOUTS_H diff --git a/spqlios/lib/test/testlib/polynomial_vector.cpp b/spqlios/lib/test/testlib/polynomial_vector.cpp new file mode 100644 index 0000000..95d4e78 --- /dev/null +++ b/spqlios/lib/test/testlib/polynomial_vector.cpp @@ -0,0 +1,69 @@ +#include "polynomial_vector.h" + +#include + +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +#define CANARY_PADDING (1024) +#define GARBAGE_VALUE (242) + +znx_vec_i64_layout::znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { + REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); + REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); + this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + this->data_start = (int64_t*)(region + CANARY_PADDING); + // ensure that any invalid value is kind-of garbage + memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + // mark inter-slice memory as non accessible +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); + VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); + } + if (size != slice) { + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t)); + } + } +#endif +} + +znx_vec_i64_layout::~znx_vec_i64_layout() { free(region); } + +znx_i64 znx_vec_i64_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i64(n, data_start + index * slice); + } else { + return znx_i64::zero(n); + } +} + +znx_i64 znx_vec_i64_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return znx_i64(n, data_start + index * slice); +} + +void znx_vec_i64_layout::set(uint64_t index, const znx_i64& elem) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); + elem.save_as(data_start + index * slice); +} + +int64_t* znx_vec_i64_layout::data() { return data_start; } +const int64_t* znx_vec_i64_layout::data() const { return data_start; } + +void znx_vec_i64_layout::fill_random(uint64_t bits) { + for (uint64_t i = 0; i < size; ++i) { + set(i, znx_i64::random_log2bound(n, bits)); + } +} +__uint128_t znx_vec_i64_layout::content_hash() const { + test_hasher hasher; + for (uint64_t i = 0; i < size; ++i) { + hasher.update(data() + i * slice, n * sizeof(int64_t)); + } + return hasher.hash(); +} diff --git a/spqlios/lib/test/testlib/polynomial_vector.h b/spqlios/lib/test/testlib/polynomial_vector.h new file mode 100644 index 0000000..d821193 --- /dev/null +++ b/spqlios/lib/test/testlib/polynomial_vector.h @@ -0,0 +1,42 @@ +#ifndef SPQLIOS_POLYNOMIAL_VECTOR_H +#define SPQLIOS_POLYNOMIAL_VECTOR_H + +#include "negacyclic_polynomial.h" +#include "test_commons.h" + +/** @brief a test memory layout for znx i64 polynomials vectors */ +class znx_vec_i64_layout { + uint64_t n; + uint64_t size; + uint64_t slice; + int64_t* data_start; + uint8_t* region; + + public: + // NO-COPY structure + znx_vec_i64_layout(const znx_vec_i64_layout&) = delete; + void operator=(const znx_vec_i64_layout&) = delete; + znx_vec_i64_layout(znx_vec_i64_layout&&) = delete; + 
void operator=(znx_vec_i64_layout&&) = delete; + /** @brief initialises a memory layout */ + znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice); + /** @brief destructor */ + ~znx_vec_i64_layout(); + + /** @brief get a copy of item index index (extended with zeros) */ + znx_i64 get_copy_zext(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + znx_i64 get_copy(uint64_t index) const; + /** @brief get a copy of item index index (index +#include + +#include "test_commons.h" + +bool is_pow2(uint64_t n) { return !(n & (n - 1)); } + +test_rng& randgen() { + static test_rng gen; + return gen; +} +uint64_t uniform_u64() { + static std::uniform_int_distribution dist64(0, UINT64_MAX); + return dist64(randgen()); +} + +uint64_t uniform_u64_bits(uint64_t nbits) { + if (nbits >= 64) return uniform_u64(); + return uniform_u64() >> (64 - nbits); +} + +int64_t uniform_i64() { + std::uniform_int_distribution dist; + return dist(randgen()); +} + +int64_t uniform_i64_bits(uint64_t nbits) { + int64_t bound = int64_t(1) << nbits; + std::uniform_int_distribution dist(-bound, bound); + return dist(randgen()); +} + +int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub) { + std::uniform_int_distribution dist(lb, ub); + return dist(randgen()); +} + +__int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub) { + std::uniform_int_distribution<__int128_t> dist(lb, ub); + return dist(randgen()); +} + +double random_f64_gaussian(double stdev) { + std::normal_distribution dist(0, stdev); + return dist(randgen()); +} + +double uniform_f64_bounds(const double lb, const double ub) { + std::uniform_real_distribution dist(lb, ub); + return dist(randgen()); +} + +double uniform_f64_01() { + return uniform_f64_bounds(0, 1); +} diff --git a/spqlios/lib/test/testlib/reim4_elem.cpp b/spqlios/lib/test/testlib/reim4_elem.cpp new file mode 100644 index 0000000..2028a31 --- /dev/null +++ b/spqlios/lib/test/testlib/reim4_elem.cpp @@ -0,0 +1,145 @@ +#include "reim4_elem.h" + +reim4_elem::reim4_elem(const double* re, const double* im) { + for (uint64_t i = 0; i < 4; ++i) { + value[i] = re[i]; + value[4 + i] = im[i]; + } +} +reim4_elem::reim4_elem(const double* layout) { + for (uint64_t i = 0; i < 8; ++i) { + value[i] = layout[i]; + } +} +reim4_elem::reim4_elem() { + for (uint64_t i = 0; i < 8; ++i) { + value[i] = 0.; + } +} +void reim4_elem::save_re_im(double* re, double* im) const { + for (uint64_t i = 0; i < 4; ++i) { + re[i] = value[i]; + im[i] = value[4 + i]; + } +} +void reim4_elem::save_as(double* reim4) const { + for (uint64_t i = 0; i < 8; ++i) { + reim4[i] = value[i]; + } +} +reim4_elem reim4_elem::zero() { return reim4_elem(); } + +bool operator==(const reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + if (x.value[i] != y.value[i]) return false; + } + return true; +} + +reim4_elem gaussian_reim4() { + test_rng& gen = randgen(); + std::normal_distribution dist(0, 1); + reim4_elem res; + for (uint64_t i = 0; i < 8; ++i) { + res.value[i] = dist(gen); + } + return res; +} + +reim4_array_view::reim4_array_view(uint64_t size, double* data) : size(size), data(data) {} +reim4_elem reim4_array_view::get(uint64_t i) const { + REQUIRE_DRAMATICALLY(i < size, "reim4 array overflow"); + return reim4_elem(data + 8 * i); +} +void reim4_array_view::set(uint64_t i, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(i < size, "reim4 array overflow"); + value.save_as(data + 8 * i); +} + +reim_view::reim_view(uint64_t m, double* data) : m(m), 
data(data) {} +reim4_elem reim_view::get_blk(uint64_t i) { + REQUIRE_DRAMATICALLY(i < m / 4, "block overflow"); + return reim4_elem(data + 4 * i, data + m + 4 * i); +} +void reim_view::set_blk(uint64_t i, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(i < m / 4, "block overflow"); + value.save_re_im(data + 4 * i, data + m + 4 * i); +} + +reim_vector_view::reim_vector_view(uint64_t m, uint64_t nrows, double* data) : m(m), nrows(nrows), data(data) {} +reim_view reim_vector_view::row(uint64_t row) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow"); + return reim_view(m, data + 2 * m * row); +} + +/** @brief addition */ +reim4_elem operator+(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 8; ++i) { + reps.value[i] = x.value[i] + y.value[i]; + } + return reps; +} +reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + x.value[i] += y.value[i]; + } + return x; +} +/** @brief subtraction */ +reim4_elem operator-(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 8; ++i) { + reps.value[i] = x.value[i] + y.value[i]; + } + return reps; +} +reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + x.value[i] -= y.value[i]; + } + return x; +} +/** @brief product */ +reim4_elem operator*(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 4; ++i) { + double xre = x.value[i]; + double yre = y.value[i]; + double xim = x.value[i + 4]; + double yim = y.value[i + 4]; + reps.value[i] = xre * yre - xim * yim; + reps.value[i + 4] = xre * yim + xim * yre; + } + return reps; +} +/** @brief distance in infty norm */ +double infty_dist(const reim4_elem& x, const reim4_elem& y) { + double dist = 0; + for (uint64_t i = 0; i < 8; ++i) { + double d = fabs(x.value[i] - y.value[i]); + if (d > dist) dist = d; + } + return dist; +} + +std::ostream& operator<<(std::ostream& out, const reim4_elem& x) { + out << "[\n"; + for (uint64_t i = 0; i < 4; ++i) { + out << " re=" << x.value[i] << ", im=" << x.value[i + 4] << "\n"; + } + return out << "]"; +} + +reim4_matrix_view::reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data) + : nrows(nrows), ncols(ncols), data(data) {} +reim4_elem reim4_matrix_view::get(uint64_t row, uint64_t col) const { + REQUIRE_DRAMATICALLY(row < nrows, "rows out of bounds" << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "cols out of bounds" << col << " / " << ncols); + return reim4_elem(data + 8 * (row * ncols + col)); +} +void reim4_matrix_view::set(uint64_t row, uint64_t col, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(row < nrows, "rows out of bounds" << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "cols out of bounds" << col << " / " << ncols); + value.save_as(data + 8 * (row * ncols + col)); +} diff --git a/spqlios/lib/test/testlib/reim4_elem.h b/spqlios/lib/test/testlib/reim4_elem.h new file mode 100644 index 0000000..68d9430 --- /dev/null +++ b/spqlios/lib/test/testlib/reim4_elem.h @@ -0,0 +1,95 @@ +#ifndef SPQLIOS_REIM4_ELEM_H +#define SPQLIOS_REIM4_ELEM_H + +#include "test_commons.h" + +/** @brief test class representing one single reim4 element */ +class reim4_elem { + public: + /** @brief 8 components (4 real parts followed by 4 imag parts) */ + double value[8]; + /** @brief constructs from 4 real parts and 4 imaginary parts */ + reim4_elem(const double* re, const double* im); + /** @brief constructs from 8 components */ + explicit 
reim4_elem(const double* layout); + /** @brief zero */ + reim4_elem(); + /** @brief saves the real parts to re and the 4 imag to im */ + void save_re_im(double* re, double* im) const; + /** @brief saves the 8 components to reim4 */ + void save_as(double* reim4) const; + static reim4_elem zero(); +}; + +/** @brief checks for equality */ +bool operator==(const reim4_elem& x, const reim4_elem& y); +/** @brief random gaussian reim4 of stdev 1 and mean 0 */ +reim4_elem gaussian_reim4(); +/** @brief addition */ +reim4_elem operator+(const reim4_elem& x, const reim4_elem& y); +reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y); +/** @brief subtraction */ +reim4_elem operator-(const reim4_elem& x, const reim4_elem& y); +reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y); +/** @brief product */ +reim4_elem operator*(const reim4_elem& x, const reim4_elem& y); +std::ostream& operator<<(std::ostream& out, const reim4_elem& x); +/** @brief distance in infty norm */ +double infty_dist(const reim4_elem& x, const reim4_elem& y); + +/** @brief test class representing the view of one reim of m complexes */ +class reim4_array_view { + uint64_t size; ///< size of the reim array + double* data; ///< pointer to the start of the array + public: + /** @brief ininitializes a view at an existing given address */ + reim4_array_view(uint64_t size, double* data); + ; + /** @brief gets the i-th element */ + reim4_elem get(uint64_t i) const; + /** @brief sets the i-th element */ + void set(uint64_t i, const reim4_elem& value); +}; + +/** @brief test class representing the view of one matrix of nrowsxncols reim4's */ +class reim4_matrix_view { + uint64_t nrows; ///< number of rows + uint64_t ncols; ///< number of columns + double* data; ///< pointer to the start of the matrix + public: + /** @brief ininitializes a view at an existing given address */ + reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data); + /** @brief gets the i-th element */ + reim4_elem get(uint64_t row, uint64_t col) const; + /** @brief sets the i-th element */ + void set(uint64_t row, uint64_t col, const reim4_elem& value); +}; + +/** @brief test class representing the view of one reim of m complexes */ +class reim_view { + uint64_t m; ///< (complex) dimension of the reim polynomial + double* data; ///< address of the start of the reim polynomial + public: + /** @brief ininitializes a view at an existing given address */ + reim_view(uint64_t m, double* data); + ; + /** @brief extracts the i-th reim4 block (i +// https://github.com/mjosaarinen/tiny_sha3 +// LICENSE: MIT + +// Revised 07-Aug-15 to match with official release of FIPS PUB 202 "SHA3" +// Revised 03-Sep-15 for portability + OpenSSL - style API + +#include "sha3.h" + +// update the state with given number of rounds + +void sha3_keccakf(uint64_t st[25]) { + // constants + const uint64_t keccakf_rndc[24] = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, + 0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, + 0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, + 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003, + 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a, + 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008}; + const int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, + 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44}; + const int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 
5, 16, 8, 21, 24, 4, 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1}; + + // variables + int i, j, r; + uint64_t t, bc[5]; + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + uint8_t* v; + + // endianess conversion. this is redundant on little-endian targets + for (i = 0; i < 25; i++) { + v = (uint8_t*)&st[i]; + st[i] = ((uint64_t)v[0]) | (((uint64_t)v[1]) << 8) | (((uint64_t)v[2]) << 16) | (((uint64_t)v[3]) << 24) | + (((uint64_t)v[4]) << 32) | (((uint64_t)v[5]) << 40) | (((uint64_t)v[6]) << 48) | (((uint64_t)v[7]) << 56); + } +#endif + + // actual iteration + for (r = 0; r < KECCAKF_ROUNDS; r++) { + // Theta + for (i = 0; i < 5; i++) bc[i] = st[i] ^ st[i + 5] ^ st[i + 10] ^ st[i + 15] ^ st[i + 20]; + + for (i = 0; i < 5; i++) { + t = bc[(i + 4) % 5] ^ ROTL64(bc[(i + 1) % 5], 1); + for (j = 0; j < 25; j += 5) st[j + i] ^= t; + } + + // Rho Pi + t = st[1]; + for (i = 0; i < 24; i++) { + j = keccakf_piln[i]; + bc[0] = st[j]; + st[j] = ROTL64(t, keccakf_rotc[i]); + t = bc[0]; + } + + // Chi + for (j = 0; j < 25; j += 5) { + for (i = 0; i < 5; i++) bc[i] = st[j + i]; + for (i = 0; i < 5; i++) st[j + i] ^= (~bc[(i + 1) % 5]) & bc[(i + 2) % 5]; + } + + // Iota + st[0] ^= keccakf_rndc[r]; + } + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + // endianess conversion. this is redundant on little-endian targets + for (i = 0; i < 25; i++) { + v = (uint8_t*)&st[i]; + t = st[i]; + v[0] = t & 0xFF; + v[1] = (t >> 8) & 0xFF; + v[2] = (t >> 16) & 0xFF; + v[3] = (t >> 24) & 0xFF; + v[4] = (t >> 32) & 0xFF; + v[5] = (t >> 40) & 0xFF; + v[6] = (t >> 48) & 0xFF; + v[7] = (t >> 56) & 0xFF; + } +#endif +} + +// Initialize the context for SHA3 + +int sha3_init(sha3_ctx_t* c, int mdlen) { + int i; + + for (i = 0; i < 25; i++) c->st.q[i] = 0; + c->mdlen = mdlen; + c->rsiz = 200 - 2 * mdlen; + c->pt = 0; + + return 1; +} + +// update state with more data + +int sha3_update(sha3_ctx_t* c, const void* data, size_t len) { + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) { + c->st.b[j++] ^= ((const uint8_t*)data)[i]; + if (j >= c->rsiz) { + sha3_keccakf(c->st.q); + j = 0; + } + } + c->pt = j; + + return 1; +} + +// finalize and output a hash + +int sha3_final(void* md, sha3_ctx_t* c) { + int i; + + c->st.b[c->pt] ^= 0x06; + c->st.b[c->rsiz - 1] ^= 0x80; + sha3_keccakf(c->st.q); + + for (i = 0; i < c->mdlen; i++) { + ((uint8_t*)md)[i] = c->st.b[i]; + } + + return 1; +} + +// compute a SHA-3 hash (md) of given byte length from "in" + +void* sha3(const void* in, size_t inlen, void* md, int mdlen) { + sha3_ctx_t sha3; + + sha3_init(&sha3, mdlen); + sha3_update(&sha3, in, inlen); + sha3_final(md, &sha3); + + return md; +} + +// SHAKE128 and SHAKE256 extensible-output functionality + +void shake_xof(sha3_ctx_t* c) { + c->st.b[c->pt] ^= 0x1F; + c->st.b[c->rsiz - 1] ^= 0x80; + sha3_keccakf(c->st.q); + c->pt = 0; +} + +void shake_out(sha3_ctx_t* c, void* out, size_t len) { + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) { + if (j >= c->rsiz) { + sha3_keccakf(c->st.q); + j = 0; + } + ((uint8_t*)out)[i] = c->st.b[j++]; + } + c->pt = j; +} diff --git a/spqlios/lib/test/testlib/sha3.h b/spqlios/lib/test/testlib/sha3.h new file mode 100644 index 0000000..08a7c86 --- /dev/null +++ b/spqlios/lib/test/testlib/sha3.h @@ -0,0 +1,56 @@ +// sha3.h +// 19-Nov-11 Markku-Juhani O. 
Saarinen +// https://github.com/mjosaarinen/tiny_sha3 +// License: MIT + +#ifndef SHA3_H +#define SHA3_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +#ifndef KECCAKF_ROUNDS +#define KECCAKF_ROUNDS 24 +#endif + +#ifndef ROTL64 +#define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y)))) +#endif + +// state context +typedef struct { + union { // state: + uint8_t b[200]; // 8-bit bytes + uint64_t q[25]; // 64-bit words + } st; + int pt, rsiz, mdlen; // these don't overflow +} sha3_ctx_t; + +// Compression function. +void sha3_keccakf(uint64_t st[25]); + +// OpenSSL - like interfece +int sha3_init(sha3_ctx_t* c, int mdlen); // mdlen = hash output in bytes +int sha3_update(sha3_ctx_t* c, const void* data, size_t len); +int sha3_final(void* md, sha3_ctx_t* c); // digest goes to md + +// compute a sha3 hash (md) of given byte length from "in" +void* sha3(const void* in, size_t inlen, void* md, int mdlen); + +// SHAKE128 and SHAKE256 extensible-output functions +#define shake128_init(c) sha3_init(c, 16) +#define shake256_init(c) sha3_init(c, 32) +#define shake_update sha3_update + +void shake_xof(sha3_ctx_t* c); +void shake_out(sha3_ctx_t* c, void* out, size_t len); + +#ifdef __cplusplus +} +#endif + +#endif // SHA3_H diff --git a/spqlios/lib/test/testlib/test_commons.cpp b/spqlios/lib/test/testlib/test_commons.cpp new file mode 100644 index 0000000..90f3606 --- /dev/null +++ b/spqlios/lib/test/testlib/test_commons.cpp @@ -0,0 +1,10 @@ +#include "test_commons.h" + +#include + +std::ostream& operator<<(std::ostream& out, __int128_t x) { + char c[35] = {0}; + snprintf(c, 35, "0x%016" PRIx64 "%016" PRIx64, uint64_t(x >> 64), uint64_t(x)); + return out << c; +} +std::ostream& operator<<(std::ostream& out, __uint128_t x) { return out << __int128_t(x); } diff --git a/spqlios/lib/test/testlib/test_commons.h b/spqlios/lib/test/testlib/test_commons.h new file mode 100644 index 0000000..fcadd79 --- /dev/null +++ b/spqlios/lib/test/testlib/test_commons.h @@ -0,0 +1,74 @@ +#ifndef SPQLIOS_TEST_COMMONS_H +#define SPQLIOS_TEST_COMMONS_H + +#include +#include + +#include "../../spqlios/commons.h" + +/** @brief macro that crashes if the condition are not met */ +#define REQUIRE_DRAMATICALLY(req_contition, error_msg) \ + do { \ + if (!(req_contition)) { \ + std::cerr << "REQUIREMENT FAILED at " << __FILE__ << ":" << __LINE__ << ": " << error_msg << std::endl; \ + abort(); \ + } \ + } while (0) + +typedef std::default_random_engine test_rng; +/** @brief reference to the default test rng */ +test_rng& randgen(); +/** @brief uniformly random 64-bit uint */ +uint64_t uniform_u64(); +/** @brief uniformly random number <= 2^nbits-1 */ +uint64_t uniform_u64_bits(uint64_t nbits); +/** @brief uniformly random signed 64-bit number */ +int64_t uniform_i64(); +/** @brief uniformly random signed |number| <= 2^nbits */ +int64_t uniform_i64_bits(uint64_t nbits); +/** @brief uniformly random signed lb <= number <= ub */ +int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub); +/** @brief uniformly random signed lb <= number <= ub */ +__int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub); +/** @brief uniformly random gaussian float64 */ +double random_f64_gaussian(double stdev = 1); +/** @brief uniformly random signed lb <= number <= ub */ +double uniform_f64_bounds(const double lb, const double ub); +/** @brief uniformly random float64 in [0,1] */ +double uniform_f64_01(); +/** @brief random gaussian float64 */ +double random_f64_gaussian(double stdev); + +bool is_pow2(uint64_t n); 
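To make the contracts of the helpers above concrete, here is a minimal usage sketch (illustrative only: `sample_draws` is a name chosen here, everything else is declared earlier in this header):

  static inline void sample_draws() {
    uint64_t u = uniform_u64_bits(12);      // 0 <= u <= 2^12 - 1
    REQUIRE_DRAMATICALLY(u < (uint64_t(1) << 12), "u out of range: " << u);
    int64_t s = uniform_i64_bounds(-5, 5);  // -5 <= s <= 5
    REQUIRE_DRAMATICALLY(s >= -5 && s <= 5, "s out of range: " << s);
    double d = uniform_f64_01();            // d in [0, 1]
    REQUIRE_DRAMATICALLY(d >= 0.0 && d <= 1.0, "d out of range: " << d);
  }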
+ +void* alloc64(uint64_t size); + +typedef __uint128_t thash; +/** @brief returns some pseudorandom hash of a contiguous content */ +thash test_hash(const void* data, uint64_t size); +/** @brief class to return a pseudorandom hash of a piecewise-defined content */ +class test_hasher { + void* md; + public: + test_hasher(); + test_hasher(const test_hasher&) = delete; + void operator=(const test_hasher&) = delete; + /** + * @brief append input bytes. + * The final hash only depends on the concatenation of bytes, not on the + * way the content was split into multiple calls to update. + */ + void update(const void* data, uint64_t size); + /** + * @brief returns the final hash. + * no more calls to update(...) shall be issued after this call. + */ + thash hash(); + ~test_hasher(); +}; + +// not included by default, since it makes some versions of gtest not compile +// std::ostream& operator<<(std::ostream& out, __int128_t x); +// std::ostream& operator<<(std::ostream& out, __uint128_t x); + +#endif // SPQLIOS_TEST_COMMONS_H diff --git a/spqlios/lib/test/testlib/test_hash.cpp b/spqlios/lib/test/testlib/test_hash.cpp new file mode 100644 index 0000000..2e065af --- /dev/null +++ b/spqlios/lib/test/testlib/test_hash.cpp @@ -0,0 +1,24 @@ +#include "sha3.h" +#include "test_commons.h" + +/** @brief returns some pseudorandom hash of the content */ +thash test_hash(const void* data, uint64_t size) { + thash res; + sha3(data, size, &res, sizeof(res)); + return res; +} +/** @brief class to return a pseudorandom hash of the content */ +test_hasher::test_hasher() { + md = malloc(sizeof(sha3_ctx_t)); + sha3_init((sha3_ctx_t*)md, 16); +} + +void test_hasher::update(const void* data, uint64_t size) { sha3_update((sha3_ctx_t*)md, data, size); } + +thash test_hasher::hash() { + thash res; + sha3_final(&res, (sha3_ctx_t*)md); + return res; +} + +test_hasher::~test_hasher() { free(md); } diff --git a/spqlios/lib/test/testlib/vec_rnx_layout.cpp b/spqlios/lib/test/testlib/vec_rnx_layout.cpp new file mode 100644 index 0000000..2a61e81 --- /dev/null +++ b/spqlios/lib/test/testlib/vec_rnx_layout.cpp @@ -0,0 +1,182 @@ +#include "vec_rnx_layout.h" + +#include + +#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" + +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +#define CANARY_PADDING (1024) +#define GARBAGE_VALUE (242) + +rnx_vec_f64_layout::rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { + REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); + REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); + this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + this->data_start = (double*)(region + CANARY_PADDING); + // ensure that any invalid value is kind-of garbage + memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + // mark inter-slice memory as not accessible +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); + VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); + } + if (size != slice) { + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t)); + } + } +#endif +} + +rnx_vec_f64_layout::~rnx_vec_f64_layout() { free(region); } + +rnx_f64 rnx_vec_f64_layout::get_copy_zext(uint64_t index) const 
{ + if (index < size) { + return rnx_f64(n, data_start + index * slice); + } else { + return rnx_f64::zero(n); + } +} + +rnx_f64 rnx_vec_f64_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return rnx_f64(n, data_start + index * slice); +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy_zext(uint64_t index) const { + if (index < size) { + return reim_fft64vec(n, data_start + index * slice); + } else { + return reim_fft64vec::zero(n); + } +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return reim_fft64vec(n, data_start + index * slice); +} + +void rnx_vec_f64_layout::set(uint64_t index, const rnx_f64& elem) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); + elem.save_as(data_start + index * slice); +} + +double* rnx_vec_f64_layout::data() { return data_start; } +const double* rnx_vec_f64_layout::data() const { return data_start; } + +void rnx_vec_f64_layout::fill_random(double log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, rnx_f64::random_log2bound(n, log2bound)); + } +} + +thash rnx_vec_f64_layout::content_hash() const { + test_hasher hasher; + for (uint64_t i = 0; i < size; ++i) { + hasher.update(data() + i * slice, n * sizeof(int64_t)); + } + return hasher.hash(); +} + +fft64_rnx_vmp_pmat_layout::fft64_rnx_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols) + : nn(n), + nrows(nrows), + ncols(ncols), // + data((RNX_VMP_PMAT*)alloc64(nrows * ncols * nn * 8)) {} + +double* fft64_rnx_vmp_pmat_layout::get_addr(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + double* d = (double*)data; + if (col == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + return d + blk * nrows * ncols * 8 // major: blk + + col * nrows * 8 // col == ncols-1 + + row * 8; + } else { + // general case: columns go by pair + return d + blk * nrows * ncols * 8 // major: blk + + (col / 2) * (2 * nrows) * 8 // second: col pair index + + row * 2 * 8 // third: row index + + (col % 2) * 8; // minor: col in colpair + } +} + +reim4_elem fft64_rnx_vmp_pmat_layout::get(uint64_t row, uint64_t col, uint64_t blk) const { + return reim4_elem(get_addr(row, col, blk)); +} +reim4_elem fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + if (row < nrows && col < ncols) { + return reim4_elem(get_addr(row, col, blk)); + } else { + return reim4_elem::zero(); + } +} +void fft64_rnx_vmp_pmat_layout::set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& value) const { + value.save_as(get_addr(row, col, blk)); +} + +fft64_rnx_vmp_pmat_layout::~fft64_rnx_vmp_pmat_layout() { spqlios_free(data); } + +reim_fft64vec fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) { + return reim_fft64vec::zero(nn); + } + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + return 
+  }
+  // otherwise, reconstruct it block by block
+  reim_fft64vec res(nn);
+  for (uint64_t blk = 0; blk < nn / 8; ++blk) {
+    reim4_elem v = get(row, col, blk);
+    res.set_blk(blk, v);
+  }
+  return res;
+}
+void fft64_rnx_vmp_pmat_layout::set(uint64_t row, uint64_t col, const reim_fft64vec& value) {
+  REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows);
+  REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols);
+  if (nn < 8) {
+    // the pmat is just col major
+    double* addr = (double*)data + (row + col * nrows) * nn;
+    value.save_as(addr);
+    return;
+  }
+  // otherwise, reconstruct it block by block
+  for (uint64_t blk = 0; blk < nn / 8; ++blk) {
+    reim4_elem v = value.get_blk(blk);
+    set(row, col, blk, v);
+  }
+}
+void fft64_rnx_vmp_pmat_layout::fill_random(double log2bound) {
+  for (uint64_t row = 0; row < nrows; ++row) {
+    for (uint64_t col = 0; col < ncols; ++col) {
+      set(row, col, reim_fft64vec::random(nn, log2bound));
+    }
+  }
+}
+
+fft64_rnx_svp_ppol_layout::fft64_rnx_svp_ppol_layout(uint64_t n)
+    : nn(n),  //
+      data((RNX_SVP_PPOL*)alloc64(nn * 8)) {}
+
+reim_fft64vec fft64_rnx_svp_ppol_layout::get_copy() const { return reim_fft64vec(nn, (double*)data); }
+
+void fft64_rnx_svp_ppol_layout::set(const reim_fft64vec& value) { value.save_as((double*)data); }
+
+void fft64_rnx_svp_ppol_layout::fill_dft_random(uint64_t log2bound) { set(reim_fft64vec::dft_random(nn, log2bound)); }
+
+void fft64_rnx_svp_ppol_layout::fill_random(double log2bound) { set(reim_fft64vec::random(nn, log2bound)); }
+
+fft64_rnx_svp_ppol_layout::~fft64_rnx_svp_ppol_layout() { spqlios_free(data); }
+thash fft64_rnx_svp_ppol_layout::content_hash() const { return test_hash(data, nn * sizeof(double)); }
\ No newline at end of file
diff --git a/spqlios/lib/test/testlib/vec_rnx_layout.h b/spqlios/lib/test/testlib/vec_rnx_layout.h
new file mode 100644
index 0000000..a92bc04
--- /dev/null
+++ b/spqlios/lib/test/testlib/vec_rnx_layout.h
@@ -0,0 +1,85 @@
+#ifndef SPQLIOS_EXT_VEC_RNX_LAYOUT_H
+#define SPQLIOS_EXT_VEC_RNX_LAYOUT_H
+
+#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h"
+#include "fft64_dft.h"
+#include "negacyclic_polynomial.h"
+#include "reim4_elem.h"
+#include "test_commons.h"
+
+/** @brief a test memory layout for rnx i64 polynomials vectors */
+class rnx_vec_f64_layout {
+  uint64_t n;
+  uint64_t size;
+  uint64_t slice;
+  double* data_start;
+  uint8_t* region;
+
+ public:
+  // NO-COPY structure
+  rnx_vec_f64_layout(const rnx_vec_f64_layout&) = delete;
+  void operator=(const rnx_vec_f64_layout&) = delete;
+  rnx_vec_f64_layout(rnx_vec_f64_layout&&) = delete;
+  void operator=(rnx_vec_f64_layout&&) = delete;
+  /** @brief initialises a memory layout */
+  rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice);
+  /** @brief destructor */
+  ~rnx_vec_f64_layout();
+
+  /** @brief get a copy of item index index (extended with zeros) */
+  rnx_f64 get_copy_zext(uint64_t index) const;
+  /** @brief get a copy of item index index (index<size) */
+  rnx_f64 get_copy(uint64_t index) const;
+  /** @brief get a copy of item index index (extended with zeros) */
+  reim_fft64vec get_dft_copy_zext(uint64_t index) const;
+  /** @brief get a copy of item index index (index<size) */
+  reim_fft64vec get_dft_copy(uint64_t index) const;
+
+  /** @brief get a copy of item index index (index> 5;
+  const uint64_t rem_ncols = ncols & 31;
+  uint64_t blk = col >> 5;
+  uint64_t col_rem = col & 31;
+  if (blk < nblk) {
+    // column is part of a full block
+    return (int32_t*)data + blk * nrows * 32 + row * 32 + col_rem;
+  } else {
+    // column is part of the last block
+    return (int32_t*)data + blk * nrows * 32 + row * rem_ncols + col_rem;
+  }
+}
+int32_t zn32_pmat_layout::get(uint64_t row, uint64_t col) const { return *get_addr(row, col); }
+int32_t zn32_pmat_layout::get_zext(uint64_t row, uint64_t col) const {
+  if (row >= nrows || col >= ncols) return 0;
+  return *get_addr(row, col);
+}
+void zn32_pmat_layout::set(uint64_t row, uint64_t col, int32_t value) { *get_addr(row, col) = value; }
+void zn32_pmat_layout::fill_random() {
+  int32_t* d = (int32_t*)data;
+  for (uint64_t i = 0; i < nrows * ncols; ++i) d[i] = uniform_i64_bits(32);
+}
+thash zn32_pmat_layout::content_hash() const { return test_hash(data, nrows * ncols * sizeof(int32_t)); }
+
+template <typename T>
+std::vector<int64_t> vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat) {
+  uint64_t rows = std::min(vec_size, mat.nrows);
+  uint64_t cols = std::min(out_size, mat.ncols);
+  std::vector<int64_t> res(out_size, 0);
+  for (uint64_t j = 0; j < cols; ++j) {
+    for (uint64_t i = 0; i < rows; ++i) {
+      res[j] += vec[i] * mat.get(i, j);
+    }
+  }
+  return res;
+}
+
+template std::vector<int64_t> vmp_product(const int8_t* vec, uint64_t vec_size, uint64_t out_size,
+                                          const zn32_pmat_layout& mat);
+template std::vector<int64_t> vmp_product(const int16_t* vec, uint64_t vec_size, uint64_t out_size,
+                                          const zn32_pmat_layout& mat);
+template std::vector<int64_t> vmp_product(const int32_t* vec, uint64_t vec_size, uint64_t out_size,
+                                          const zn32_pmat_layout& mat);
diff --git a/spqlios/lib/test/testlib/zn_layouts.h b/spqlios/lib/test/testlib/zn_layouts.h
new file mode 100644
index 0000000..b36ce3e
--- /dev/null
+++ b/spqlios/lib/test/testlib/zn_layouts.h
@@ -0,0 +1,29 @@
+#ifndef SPQLIOS_EXT_ZN_LAYOUTS_H
+#define SPQLIOS_EXT_ZN_LAYOUTS_H
+
+#include "../../spqlios/arithmetic/zn_arithmetic.h"
+#include "test_commons.h"
+
+class zn32_pmat_layout {
+ public:
+  const uint64_t nrows;
+  const uint64_t ncols;
+  ZN32_VMP_PMAT* const data;
+  zn32_pmat_layout(uint64_t nrows, uint64_t ncols);
+
+ private:
+  int32_t* get_addr(uint64_t row, uint64_t col) const;
+
+ public:
+  int32_t get(uint64_t row, uint64_t col) const;
+  int32_t get_zext(uint64_t row, uint64_t col) const;
+  void set(uint64_t row, uint64_t col, int32_t value);
+  void fill_random();
+  thash content_hash() const;
+  ~zn32_pmat_layout();
+};
+
+template <typename T>
+std::vector<int64_t> vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat);
+
+#endif  // SPQLIOS_EXT_ZN_LAYOUTS_H
diff --git a/spqlios/src/lib.rs b/spqlios/src/lib.rs
new file mode 100644
index 0000000..945716e
--- /dev/null
+++ b/spqlios/src/lib.rs
@@ -0,0 +1,15 @@
+pub mod module;
+pub mod poly;
+
+#[allow(
+    non_camel_case_types,
+    non_snake_case,
+    non_upper_case_globals,
+    dead_code,
+    improper_ctypes
+)]
+pub mod bindings {
+    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
+}
+
+pub use bindings::*;
diff --git a/spqlios/src/mod.rs b/spqlios/src/mod.rs
new file mode 100644
index 0000000..a94a210
--- /dev/null
+++ b/spqlios/src/mod.rs
@@ -0,0 +1 @@
+pub mod module;
\ No newline at end of file
diff --git a/spqlios/src/module.rs b/spqlios/src/module.rs
new file mode 100644
index 0000000..84b6136
--- /dev/null
+++ b/spqlios/src/module.rs
@@ -0,0 +1,91 @@
+use crate::bindings::*;
+
+pub fn create_module(N: u64, mtype: module_type_t) -> *mut MODULE {
+    unsafe {
+        let m = new_module_info(N, mtype);
+        if m.is_null() {
+            println!("Failed to create module.");
+        }
+        m
+    }
+}
+
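+// Ownership note (illustrative sketch): the caller keeps the raw pointer
+// returned by `create_module`. A thin RAII wrapper could look like the
+// commented-out sketch below; it assumes the generated bindings also expose a
+// matching destructor (the name `delete_module_info` is an assumption, it is
+// not part of this patch).
+//
+// pub struct Module(*mut MODULE);
+//
+// impl Module {
+//     pub fn new(n: u64, mtype: module_type_t) -> Self {
+//         Self(create_module(n, mtype))
+//     }
+// }
+//
+// impl Drop for Module {
+//     fn drop(&mut self) {
+//         unsafe { delete_module_info(self.0) };
+//     }
+// }
+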
+#[test]
+fn test_new_module_info() {
+    let N: u64 = 1024;
+    let module_ptr: *mut module_info_t = create_module(N, module_type_t_FFT64);
+    assert!(!module_ptr.is_null());
+    println!("{:?}", module_ptr);
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::ffi::c_void;
+    use std::time::Instant;
+    //use test::Bencher;
+
+    #[test]
+    fn test_fft() {
+        let log_bound: usize = 19;
+
+        let n: usize = 2048;
+        let m: usize = n >> 1;
+
+        let mut a: Vec<i64> = vec![i64::default(); n];
+        let mut b: Vec<i64> = vec![i64::default(); n];
+        let mut c: Vec<i64> = vec![i64::default(); n];
+
+        a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
+        b[1] = 1;
+
+        println!("{:?}", b);
+
+        unsafe {
+            let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
+            let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
+
+            let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
+            let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
+            let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
+
+            let now = Instant::now();
+            (0..1024).for_each(|_| {
+                reim_from_znx64_simple(
+                    m as u32,
+                    log_bound as u32,
+                    buf_a as *mut c_void,
+                    a.as_ptr(),
+                );
+                reim_fft(reim_fft_precomp, buf_a);
+
+                reim_from_znx64_simple(
+                    m as u32,
+                    log_bound as u32,
+                    buf_b as *mut c_void,
+                    b.as_ptr(),
+                );
+                reim_fft(reim_fft_precomp, buf_b);
+
+                reim_fftvec_mul_simple(
+                    m as u32,
+                    buf_c as *mut c_void,
+                    buf_a as *mut c_void,
+                    buf_b as *mut c_void,
+                );
+                reim_ifft(reim_ifft_precomp, buf_c);
+
+                reim_to_znx64_simple(
+                    m as u32,
+                    m as f64,
+                    log_bound as u32,
+                    c.as_mut_ptr(),
+                    buf_c as *mut c_void,
+                )
+            });
+
+            println!("time: {}us", now.elapsed().as_micros());
+            println!("{:?}", &c[..16]);
+        }
+    }
+}
diff --git a/spqlios/src/poly.rs b/spqlios/src/poly.rs
new file mode 100644
index 0000000..a7c940f
--- /dev/null
+++ b/spqlios/src/poly.rs
@@ -0,0 +1,190 @@
+use crate::{znx_normalize, znx_zero_i64_ref};
+use itertools::izip;
+use std::cmp::min;
+
+pub struct Poly {
+    pub n: usize,
+    pub k: usize,
+    pub prec: usize,
+    pub data: Vec<i64>,
+}
+
+impl Poly {
+    pub fn new(n: usize, k: usize, prec: usize) -> Self {
+        Self {
+            n: n,
+            k: k,
+            prec: prec,
+            data: vec![i64::default(); Self::buffer_size(n, k, prec)],
+        }
+    }
+
+    pub fn buffer_size(n: usize, k: usize, prec: usize) -> usize {
+        n * ((prec + k - 1) / k)
+    }
+
+    pub fn from_buffer(&mut self, n: usize, k: usize, prec: usize, buf: &[i64]) {
+        let size = Self::buffer_size(n, k, prec);
+        assert!(
+            buf.len() >= size,
+            "invalid buffer: buf.len()={} < self.buffer_size(n={}, k={}, prec={})={}",
+            buf.len(),
+            n,
+            k,
+            prec,
+            size
+        );
+        self.n = n;
+        self.k = k;
+        self.prec = prec;
+        self.data = Vec::from(&buf[..size]);
+    }
+
+    pub fn log_n(&self) -> usize {
+        (usize::BITS - (self.n - 1).leading_zeros()) as _
+    }
+
+    pub fn n(&self) -> usize {
+        self.n
+    }
+
+    pub fn limbs(&self) -> usize {
+        self.data.len() / self.n
+    }
+
+    pub fn at(&self, i: usize) -> &[i64] {
+        &self.data[i * self.n..(i + 1) * self.n]
+    }
+
+    pub fn at_ptr(&self, i: usize) -> *const i64 {
+        &self.data[i * self.n] as *const i64
+    }
+
+    pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
+        &mut self.data[i * self.n] as *mut i64
+    }
+
+    pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
+        &mut self.data[i * self.n..(i + 1) * self.n]
+    }
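+
+    /// Sets up to `n` coefficients from `data`, whose entries are assumed to be
+    /// bounded in absolute value by 2^`log_max`. Values that still fit in an
+    /// i64 after the final alignment shift are written directly into the
+    /// least-significant limb (the last one); larger values are decomposed in
+    /// base 2^`k` over the low limbs. When `prec` is not a multiple of `k`, the
+    /// written limbs are then shifted left by `k - prec % k` bits so that the
+    /// digits are aligned to the declared precision.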
+    pub fn set_i64(&mut self, data: &[i64], log_max: usize) {
+        let size: usize = min(data.len(), self.n());
+        let k_rem: usize = self.k - (self.prec % self.k);
+
+        // If 2^{log_max} * 2^{k_rem} < 2^{63}, then we can simply copy the
+        // values onto the last limb.
+        // Else we decompose the values in base 2^k.
+        if log_max + k_rem < 63 || k_rem == self.k {
+            self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]);
+        } else {
+            let mask: i64 = (1 << self.k) - 1;
+            let limbs = self.limbs();
+            let steps: usize = min(limbs, (log_max + k_rem + self.k - 1) / self.k);
+            (limbs - steps..limbs)
+                .rev()
+                .enumerate()
+                .for_each(|(i, i_rev)| {
+                    let shift: usize = i * self.k;
+                    izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter())
+                        .for_each(|(y, x)| *y = (x >> shift) & mask);
+                })
+        }
+
+        // Case where self.prec % self.k != 0.
+        if k_rem != self.k {
+            let limbs = self.limbs();
+            let steps: usize = min(limbs, (log_max + k_rem + self.k - 1) / self.k);
+            (limbs - steps..limbs).rev().for_each(|i| {
+                self.at_mut(i)[..size].iter_mut().for_each(|x| *x <<= k_rem);
+            })
+        }
+    }
+
+    pub fn normalize(&mut self, carry: &mut [i64]) {
+        assert!(
+            carry.len() >= self.n,
+            "invalid carry: carry.len()={} < self.n()={}",
+            carry.len(),
+            self.n()
+        );
+        unsafe {
+            znx_zero_i64_ref(self.n() as u64, carry.as_mut_ptr());
+            (0..self.limbs()).rev().for_each(|i| {
+                znx_normalize(
+                    self.n as u64,
+                    self.k as u64,
+                    self.at_mut_ptr(i),
+                    carry.as_mut_ptr(),
+                    self.at_mut_ptr(i),
+                    carry.as_mut_ptr(),
+                )
+            });
+        }
+    }
+
+    pub fn get_i64(&self, data: &mut [i64]) {
+        assert!(
+            data.len() >= self.n,
+            "invalid data: data.len()={} < self.n()={}",
+            data.len(),
+            self.n
+        );
+        data[..self.n].copy_from_slice(self.at(0));
+        let rem: usize = self.k - (self.prec % self.k);
+        (1..self.limbs()).for_each(|i| {
+            if i == self.limbs() - 1 && rem != self.k {
+                let k_rem: usize = self.k - rem;
+                izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| {
+                    *y = (*y << k_rem) + (x >> rem);
+                });
+            } else {
+                izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| {
+                    *y = (*y << self.k) + x;
+                });
+            }
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::poly::Poly;
+    use itertools::izip;
+    use sampling::source::Source;
+
+    #[test]
+    fn test_set_get_i64_lo_norm() {
+        let n: usize = 32;
+        let k: usize = 19;
+        let prec: usize = 128;
+        let mut a: Poly = Poly::new(n, k, prec);
+        let mut have: Vec<i64> = vec![i64::default(); n];
+        have.iter_mut()
+            .enumerate()
+            .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2);
+        a.set_i64(&have, 10);
+        let mut want = vec![i64::default(); n];
+        a.get_i64(&mut want);
+        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
+    }
+
+    #[test]
+    fn test_set_get_i64_hi_norm() {
+        let n: usize = 1;
+        let k: usize = 19;
+        let prec: usize = 128;
+        let mut a: Poly = Poly::new(n, k, prec);
+        let mut have: Vec<i64> = vec![i64::default(); n];
+        let mut source = Source::new([1; 32]);
+        have.iter_mut().for_each(|x| {
+            *x = source
+                .next_u64n(u64::MAX, u64::MAX)
+                .wrapping_sub(u64::MAX / 2 + 1) as i64;
+        });
+        a.set_i64(&have, 63);
+        let mut want = vec![i64::default(); n];
+        a.get_i64(&mut want);
+        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
+    }
+}
diff --git a/spqlios/tests/module.rs b/spqlios/tests/module.rs
new file mode 100644
index 0000000..2fb065a
--- /dev/null
+++ b/spqlios/tests/module.rs
@@ -0,0 +1,9 @@
+use spqlios::bindings::{module_info_t, module_type_t_FFT64};
+use spqlios::module::create_module;
+
+#[test]
+fn test_new_module_info() {
+    let N: u64 = 1024;
+    let module_ptr: *mut module_info_t = create_module(N, module_type_t_FFT64);
+    assert!(!module_ptr.is_null());
+}
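
Note on the `Poly` limb layout introduced above: `prec` bits of precision in base 2^`k` are stored as `ceil(prec/k)` limbs of `n` coefficients each, most-significant limb first. The sketch below is not part of the patch; it only uses the `spqlios` crate API added here and mirrors what the `test_set_get_i64_lo_norm` test checks, with the limb count made explicit for the parameters used in the tests:

    use spqlios::poly::Poly;

    fn main() {
        let (n, k, prec) = (32usize, 19usize, 128usize);

        // ceil(128 / 19) = 7 limbs of n coefficients each
        assert_eq!(Poly::buffer_size(n, k, prec), n * 7);

        let mut p = Poly::new(n, k, prec);
        assert_eq!(p.limbs(), 7);

        // small centered coefficients, bounded by 2^10
        let coeffs: Vec<i64> = (0..n as i64).map(|i| i - (n as i64) / 2).collect();
        p.set_i64(&coeffs, 10);

        // reading back at full precision returns the same values
        let mut out = vec![0i64; n];
        p.get_i64(&mut out);
        assert_eq!(out, coeffs);

        // p.normalize(&mut carry) can then reduce every limb to k-bit digits;
        // it is backed by the C kernel znx_normalize and is not exercised here.
    }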