mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
spqlios basic wrapper
This commit is contained in:
113
Cargo.lock
generated
113
Cargo.lock
generated
@@ -49,6 +49,32 @@ version = "1.4.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bindgen"
|
||||||
|
version = "0.71.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cexpr",
|
||||||
|
"clang-sys",
|
||||||
|
"itertools 0.10.5",
|
||||||
|
"log",
|
||||||
|
"prettyplease",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
"shlex",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitflags"
|
||||||
|
version = "2.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.16.0"
|
version = "3.16.0"
|
||||||
@@ -67,6 +93,15 @@ version = "0.3.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cexpr"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||||
|
dependencies = [
|
||||||
|
"nom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cfg-if"
|
name = "cfg-if"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@@ -100,6 +135,17 @@ dependencies = [
|
|||||||
"half",
|
"half",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clang-sys"
|
||||||
|
version = "1.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"libc",
|
||||||
|
"libloading",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap"
|
name = "clap"
|
||||||
version = "4.5.23"
|
version = "4.5.23"
|
||||||
@@ -215,6 +261,12 @@ dependencies = [
|
|||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "half"
|
name = "half"
|
||||||
version = "2.4.1"
|
version = "2.4.1"
|
||||||
@@ -288,6 +340,16 @@ version = "0.2.167"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc"
|
checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libloading"
|
||||||
|
version = "0.8.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libm"
|
name = "libm"
|
||||||
version = "0.2.11"
|
version = "0.2.11"
|
||||||
@@ -334,6 +396,12 @@ version = "2.7.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minimal-lexical"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray"
|
name = "ndarray"
|
||||||
version = "0.16.1"
|
version = "0.16.1"
|
||||||
@@ -349,6 +417,16 @@ dependencies = [
|
|||||||
"rawpointer",
|
"rawpointer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nom"
|
||||||
|
version = "7.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"minimal-lexical",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num"
|
name = "num"
|
||||||
version = "0.4.3"
|
version = "0.4.3"
|
||||||
@@ -507,6 +585,16 @@ dependencies = [
|
|||||||
"zerocopy",
|
"zerocopy",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "prettyplease"
|
||||||
|
version = "0.2.29"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "primality-test"
|
name = "primality-test"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
@@ -637,6 +725,12 @@ version = "0.8.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc-hash"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.18"
|
version = "1.0.18"
|
||||||
@@ -692,12 +786,27 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shlex"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "smallvec"
|
name = "smallvec"
|
||||||
version = "1.13.2"
|
version = "1.13.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "spqlios"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bindgen",
|
||||||
|
"itertools 0.14.0",
|
||||||
|
"sampling",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sprs"
|
name = "sprs"
|
||||||
version = "0.11.2"
|
version = "0.11.2"
|
||||||
@@ -715,9 +824,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.90"
|
version = "2.0.96"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31"
|
checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["math", "sampling", "utils"]
|
members = ["math", "sampling", "spqlios", "utils"]
|
||||||
|
|||||||
11
spqlios/Cargo.toml
Normal file
11
spqlios/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
[package]
|
||||||
|
name = "spqlios"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
itertools = "0.14.0"
|
||||||
|
sampling = { path = "../sampling" }
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
bindgen = "0.71.1"
|
||||||
52
spqlios/build.rs
Normal file
52
spqlios/build.rs
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
use bindgen;
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::absolute;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// Path to the C header file
|
||||||
|
let header_paths = [
|
||||||
|
"lib/spqlios/coeffs/coeffs_arithmetic.h",
|
||||||
|
"lib/spqlios/arithmetic/vec_znx_arithmetic.h",
|
||||||
|
];
|
||||||
|
|
||||||
|
let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||||
|
let bindings_file = out_path.join("bindings.rs");
|
||||||
|
|
||||||
|
let regenerate: bool = header_paths.iter().any(|header| {
|
||||||
|
let header_metadata: SystemTime = fs::metadata(header)
|
||||||
|
.and_then(|m| m.modified())
|
||||||
|
.unwrap_or(SystemTime::UNIX_EPOCH);
|
||||||
|
let bindings_metadata: SystemTime = fs::metadata(&bindings_file)
|
||||||
|
.and_then(|m| m.modified())
|
||||||
|
.unwrap_or(SystemTime::UNIX_EPOCH);
|
||||||
|
header_metadata > bindings_metadata
|
||||||
|
});
|
||||||
|
|
||||||
|
if regenerate {
|
||||||
|
// Generate the Rust bindings
|
||||||
|
let mut builder: bindgen::Builder = bindgen::Builder::default();
|
||||||
|
for header in header_paths {
|
||||||
|
builder = builder.header(header);
|
||||||
|
}
|
||||||
|
|
||||||
|
let bindings = builder
|
||||||
|
.generate_comments(false) // Optional: includes comments in bindings
|
||||||
|
.generate_inline_functions(true) // Optional: includes inline functions
|
||||||
|
.generate()
|
||||||
|
.expect("Unable to generate bindings");
|
||||||
|
|
||||||
|
// Write the bindings to the OUT_DIR
|
||||||
|
bindings
|
||||||
|
.write_to_file(&bindings_file)
|
||||||
|
.expect("Couldn't write bindings!");
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"cargo:rustc-link-search=native={}",
|
||||||
|
absolute("./lib/build/spqlios").unwrap().to_str().unwrap()
|
||||||
|
);
|
||||||
|
println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios"
|
||||||
|
}
|
||||||
57
spqlios/examples/fft.rs
Normal file
57
spqlios/examples/fft.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
use std::ffi::c_void;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use spqlios::bindings::*;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let log_bound: usize = 19;
|
||||||
|
|
||||||
|
let n: usize = 2048;
|
||||||
|
let m: usize = n >> 1;
|
||||||
|
|
||||||
|
let mut a: Vec<i64> = vec![i64::default(); n];
|
||||||
|
let mut b: Vec<i64> = vec![i64::default(); n];
|
||||||
|
let mut c: Vec<i64> = vec![i64::default(); n];
|
||||||
|
|
||||||
|
a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||||
|
b[1] = 1;
|
||||||
|
|
||||||
|
println!("{:?}", b);
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
|
||||||
|
let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
|
||||||
|
|
||||||
|
let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
|
||||||
|
let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
|
||||||
|
let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
(0..1024).for_each(|_| {
|
||||||
|
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||||
|
reim_fft(reim_fft_precomp, buf_a);
|
||||||
|
|
||||||
|
reim_from_znx64_simple(m as u32, log_bound as u32, buf_b as *mut c_void, b.as_ptr());
|
||||||
|
reim_fft(reim_fft_precomp, buf_b);
|
||||||
|
|
||||||
|
reim_fftvec_mul_simple(
|
||||||
|
m as u32,
|
||||||
|
buf_c as *mut c_void,
|
||||||
|
buf_a as *mut c_void,
|
||||||
|
buf_b as *mut c_void,
|
||||||
|
);
|
||||||
|
reim_ifft(reim_ifft_precomp, buf_c);
|
||||||
|
|
||||||
|
reim_to_znx64_simple(
|
||||||
|
m as u32,
|
||||||
|
m as f64,
|
||||||
|
log_bound as u32,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
buf_c as *mut c_void,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
println!("time: {}us", now.elapsed().as_micros());
|
||||||
|
println!("{:?}", &c[..16]);
|
||||||
|
}
|
||||||
|
}
|
||||||
14
spqlios/lib/.clang-format
Normal file
14
spqlios/lib/.clang-format
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Use the Google style in this project.
|
||||||
|
BasedOnStyle: Google
|
||||||
|
|
||||||
|
# Some folks prefer to write "int& foo" while others prefer "int &foo". The
|
||||||
|
# Google Style Guide only asks for consistency within a project, we chose
|
||||||
|
# "int& foo" for this project:
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
PointerAlignment: Left
|
||||||
|
|
||||||
|
# The Google Style Guide only asks for consistency w.r.t. "east const" vs.
|
||||||
|
# "const west" alignment of cv-qualifiers. In this project we use "east const".
|
||||||
|
QualifierAlignment: Left
|
||||||
|
|
||||||
|
ColumnLimit: 120
|
||||||
4
spqlios/lib/.gitignore
vendored
Normal file
4
spqlios/lib/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
cmake-build-*
|
||||||
|
.idea
|
||||||
|
|
||||||
|
build/
|
||||||
69
spqlios/lib/CMakeLists.txt
Normal file
69
spqlios/lib/CMakeLists.txt
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(spqlios)
|
||||||
|
|
||||||
|
# read the current version from the manifest file
|
||||||
|
file(READ "manifest.yaml" manifest)
|
||||||
|
string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest})
|
||||||
|
#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}")
|
||||||
|
set(SPQLIOS_VERSION ${CMAKE_MATCH_1})
|
||||||
|
set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2})
|
||||||
|
set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3})
|
||||||
|
set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4})
|
||||||
|
message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}")
|
||||||
|
|
||||||
|
#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath")
|
||||||
|
set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors")
|
||||||
|
set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests")
|
||||||
|
set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)")
|
||||||
|
|
||||||
|
if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
|
||||||
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE)
|
||||||
|
endif()
|
||||||
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
|
|
||||||
|
if (WARNING_PARANOID)
|
||||||
|
add_compile_options(-Wall -Werror -Wno-unused-command-line-argument)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
|
||||||
|
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
|
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
|
||||||
|
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
|
||||||
|
set(X86 ON)
|
||||||
|
set(AARCH64 OFF)
|
||||||
|
else ()
|
||||||
|
set(X86 OFF)
|
||||||
|
# set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets
|
||||||
|
endif ()
|
||||||
|
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
|
||||||
|
set(AARCH64 ON)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)")
|
||||||
|
set(WIN32 ON)
|
||||||
|
endif ()
|
||||||
|
if (WIN32)
|
||||||
|
#overrides for win32
|
||||||
|
set(X86 OFF)
|
||||||
|
set(AARCH64 OFF)
|
||||||
|
set(X86_WIN32 ON)
|
||||||
|
else()
|
||||||
|
set(X86_WIN32 OFF)
|
||||||
|
set(WIN32 OFF)
|
||||||
|
endif (WIN32)
|
||||||
|
|
||||||
|
message(STATUS "--> WIN32: ${WIN32}")
|
||||||
|
message(STATUS "--> X86_WIN32: ${X86_WIN32}")
|
||||||
|
message(STATUS "--> X86_LINUX: ${X86}")
|
||||||
|
message(STATUS "--> AARCH64: ${AARCH64}")
|
||||||
|
|
||||||
|
# compiles the main library in spqlios
|
||||||
|
add_subdirectory(spqlios)
|
||||||
|
|
||||||
|
# compiles and activates unittests and itests
|
||||||
|
if (${ENABLE_TESTING})
|
||||||
|
enable_testing()
|
||||||
|
add_subdirectory(test)
|
||||||
|
endif()
|
||||||
|
|
||||||
77
spqlios/lib/CONTRIBUTING.md
Normal file
77
spqlios/lib/CONTRIBUTING.md
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Contributing to SPQlios-fft
|
||||||
|
|
||||||
|
The spqlios-fft team encourages contributions.
|
||||||
|
We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features.
|
||||||
|
We encourage researchers to contribute with implementations of their FFT or NTT algorithms.
|
||||||
|
In the following we are trying to give some guidance on how to contribute effectively.
|
||||||
|
|
||||||
|
## Communication ##
|
||||||
|
|
||||||
|
Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).
|
||||||
|
|
||||||
|
All communications are public, so please make sure to maintain professional behaviour in
|
||||||
|
all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
|
||||||
|
guidelines.
|
||||||
|
|
||||||
|
## Reporting Bugs or Requesting features ##
|
||||||
|
|
||||||
|
Bug should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).
|
||||||
|
|
||||||
|
Features can also be requested there, in this case, please ensure that the features you request are self-contained,
|
||||||
|
easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
|
||||||
|
possible.
|
||||||
|
|
||||||
|
## Setting up topic branches and generating pull requests
|
||||||
|
|
||||||
|
This section applies to people that already have write access to the repository. Specific instructions for pull-requests
|
||||||
|
from public forks will be given later.
|
||||||
|
|
||||||
|
To implement some changes, please follow these steps:
|
||||||
|
|
||||||
|
- Create a "topic branch". Usually, the branch name should be `username/small-title`
|
||||||
|
or better `username/issuenumber-small-title` where `issuenumber` is the number of
|
||||||
|
the github issue number that is tackled.
|
||||||
|
- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
|
||||||
|
- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm`
|
||||||
|
- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.
|
||||||
|
|
||||||
|
> _During the process, you may optionnally use `git rebase -i` to clean up your commit history. If you elect to do so,
|
||||||
|
please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate
|
||||||
|
and the human hours to fix them are not worth it. `Git merge` remains the preferred option._
|
||||||
|
|
||||||
|
- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage.
|
||||||
|
|
||||||
|
### Keep your pull requests limited to a single issue
|
||||||
|
|
||||||
|
Pull requests should be as small/atomic as possible.
|
||||||
|
|
||||||
|
### Coding Conventions
|
||||||
|
|
||||||
|
* Please make sure that your code is formatted according to the `.clang-format` file and
|
||||||
|
that all files end with a newline character.
|
||||||
|
* Please make sure that all the functions declared in the public api have relevant doxygen comments.
|
||||||
|
Preferably, functions in the private apis should also contain a brief doxygen description.
|
||||||
|
|
||||||
|
### Versions and History
|
||||||
|
|
||||||
|
* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
|
||||||
|
the form `x.y.z`
|
||||||
|
* a patch release that increments `z` does not modify the stable API.
|
||||||
|
* a minor release that increments `y` adds a new feature to the stable API.
|
||||||
|
* In the unlikely case where we need to change or remove a feature, we will trigger a major release that
|
||||||
|
increments `x`.
|
||||||
|
|
||||||
|
> _If any, we will mark those features as deprecated at least six months before the major release._
|
||||||
|
|
||||||
|
* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
|
||||||
|
your own risk,
|
||||||
|
but keep in mind that semantic versioning does not apply to them.
|
||||||
|
|
||||||
|
> _If you have a use-case that uses an experimental feature, we encourage
|
||||||
|
> you to tell us about it, so that this feature reaches to the stable section faster!_
|
||||||
|
|
||||||
|
* **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
|
||||||
|
get insight about
|
||||||
|
the history of the repository (not the commit graph).
|
||||||
|
|
||||||
|
> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenario where multiple features are developped in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_
|
||||||
18
spqlios/lib/Changelog.md
Normal file
18
spqlios/lib/Changelog.md
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [2.0.0] - 2024-08-21
|
||||||
|
|
||||||
|
- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis.
|
||||||
|
- Hardware acceleration available: AVX2 (most parts)
|
||||||
|
- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode".
|
||||||
|
|
||||||
|
## [1.0.0] - 2023-07-18
|
||||||
|
|
||||||
|
- Initial release of the double precision fft on the reim and cplx backends
|
||||||
|
- Coeffs-space conversions cplx <-> znx32 and tnx32
|
||||||
|
- FFT-space conversions cplx <-> reim4 layouts
|
||||||
|
- FFT-space multiplications on the cplx, reim and reim4 layouts.
|
||||||
|
- In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
|
||||||
201
spqlios/lib/LICENSE
Normal file
201
spqlios/lib/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
65
spqlios/lib/README.md
Normal file
65
spqlios/lib/README.md
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# SPQlios library
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography.
|
||||||
|
|
||||||
|
<img src="docs/api-full.svg">
|
||||||
|
|
||||||
|
Namely, it is divided into 4 sections:
|
||||||
|
|
||||||
|
* The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space.
|
||||||
|
* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. It proposed in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occurs on gadget-decomposed Ring-LWE coordinates.
|
||||||
|
* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings.
|
||||||
|
* The ZN section focuses over vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR).
|
||||||
|
|
||||||
|
### A high value target for hardware accelerations
|
||||||
|
|
||||||
|
SPQlios is more than a library, it is also a good target for hardware developers.
|
||||||
|
On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operations is sufficiently large so that meaningful functions only require a few of these.
|
||||||
|
|
||||||
|
This makes the SPQlios API a high value target for hardware acceleration, that targets FHE.
|
||||||
|
|
||||||
|
### SPQLios is not an FHE library, but a huge enabler
|
||||||
|
|
||||||
|
SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.
|
||||||
|
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
|
||||||
|
interface can be used in a regular C code, and any other language via classical foreign APIs.
|
||||||
|
|
||||||
|
The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on
|
||||||
|
[```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.
|
||||||
|
|
||||||
|
Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). In the future, we plan to
|
||||||
|
extend the compatibility to other compilers, platforms and operating systems.
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well.
|
||||||
|
|
||||||
|
It will compile the shared library in optimized mode, and ```make install``` install it to the desired prefix folder (by default ```/usr/local/lib```).
|
||||||
|
|
||||||
|
If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
|
||||||
|
```
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake ../src -CMAKE_INSTALL_PREFIX=/usr/
|
||||||
|
make
|
||||||
|
```
|
||||||
|
The available options are the following:
|
||||||
|
|
||||||
|
| Variable Name | values |
|
||||||
|
| -------------------- | ------------------------------------------------------------ |
|
||||||
|
| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) |
|
||||||
|
| WARNING_PARANOID | All warnings are shown and treated as errors. Off by default |
|
||||||
|
| ENABLE_TESTING | Compiles unit tests and integration tests |
|
||||||
|
|
||||||
|
------
|
||||||
|
|
||||||
|
<img src="docs/logo-sandboxaq-black.svg">
|
||||||
|
|
||||||
|
<img src="docs/logo-inpher1.png">
|
||||||
416
spqlios/lib/docs/api-full.svg
Normal file
416
spqlios/lib/docs/api-full.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 550 KiB |
BIN
spqlios/lib/docs/logo-inpher1.png
Normal file
BIN
spqlios/lib/docs/logo-inpher1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
spqlios/lib/docs/logo-inpher2.png
Normal file
BIN
spqlios/lib/docs/logo-inpher2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
139
spqlios/lib/docs/logo-sandboxaq-black.svg
Normal file
139
spqlios/lib/docs/logo-sandboxaq-black.svg
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||||
|
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||||
|
|
||||||
|
<svg
|
||||||
|
version="1.1"
|
||||||
|
id="Layer_1"
|
||||||
|
x="0px"
|
||||||
|
y="0px"
|
||||||
|
viewBox="0 0 270 49.4"
|
||||||
|
style="enable-background:new 0 0 270 49.4;"
|
||||||
|
xml:space="preserve"
|
||||||
|
sodipodi:docname="logo-sandboxaq-black.svg"
|
||||||
|
inkscape:version="1.3.2 (1:1.3.2+202311252150+091e20ef0f)"
|
||||||
|
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||||
|
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||||
|
id="defs9839">
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</defs><sodipodi:namedview
|
||||||
|
id="namedview9837"
|
||||||
|
pagecolor="#ffffff"
|
||||||
|
bordercolor="#000000"
|
||||||
|
borderopacity="0.25"
|
||||||
|
inkscape:showpageshadow="2"
|
||||||
|
inkscape:pageopacity="0.0"
|
||||||
|
inkscape:pagecheckerboard="0"
|
||||||
|
inkscape:deskcolor="#d1d1d1"
|
||||||
|
showgrid="false"
|
||||||
|
inkscape:zoom="1.194332"
|
||||||
|
inkscape:cx="135.64068"
|
||||||
|
inkscape:cy="25.118645"
|
||||||
|
inkscape:window-width="804"
|
||||||
|
inkscape:window-height="436"
|
||||||
|
inkscape:window-x="190"
|
||||||
|
inkscape:window-y="27"
|
||||||
|
inkscape:window-maximized="0"
|
||||||
|
inkscape:current-layer="Layer_1" />
|
||||||
|
<style
|
||||||
|
type="text/css"
|
||||||
|
id="style9786">
|
||||||
|
.st0{fill:#EBB028;}
|
||||||
|
.st1{fill:#FFFFFF;}
|
||||||
|
</style>
|
||||||
|
<text
|
||||||
|
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||||
|
class="st1"
|
||||||
|
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||||
|
id="text9788">SANDBOX </text>
|
||||||
|
<text
|
||||||
|
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||||
|
class="st1"
|
||||||
|
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||||
|
id="text9790">AQ</text>
|
||||||
|
<g
|
||||||
|
id="g9808">
|
||||||
|
<g
|
||||||
|
id="g9800">
|
||||||
|
<g
|
||||||
|
id="g9798">
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="m 8.9,9.7 v 3.9 l 29.6,17.1 v 2.7 c 0,1.2 -0.6,2.3 -1.6,2.9 L 31,39.8 v -4 L 1.4,18.6 V 15.9 C 1.4,14.7 2,13.6 3.1,13 Z"
|
||||||
|
id="path9792" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M 18.3,45.1 3.1,36.3 C 2.1,35.7 1.4,34.6 1.4,33.4 V 26 L 28,41.4 21.5,45.1 c -0.9,0.6 -2.2,0.6 -3.2,0 z"
|
||||||
|
id="path9794" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="m 21.6,4.3 15.2,8.8 c 1,0.6 1.7,1.7 1.7,2.9 v 7.5 L 11.8,8 18.3,4.3 c 1,-0.6 2.3,-0.6 3.3,0 z"
|
||||||
|
id="path9796" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<g
|
||||||
|
id="g9806">
|
||||||
|
<polygon
|
||||||
|
class="st0"
|
||||||
|
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||||
|
id="polygon9802" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="m 246.9,31 -0.1,-0.1 h -0.1 c -0.2,0 -0.4,0 -0.6,0 -3.5,0 -5.7,-2.6 -5.7,-6.7 0,-4.1 2.2,-6.7 5.7,-6.7 3.5,0 5.7,2.6 5.7,6.7 0,0.3 0,0.6 0,0.9 l 3.6,4.2 c 0.7,-1.5 1,-3.2 1,-5.1 0,-6.5 -4.2,-11 -10.3,-11 -6.1,0 -10.3,4.5 -10.3,11 0,6.5 4.2,11 10.3,11 1.2,0 2.3,-0.2 3.4,-0.5 l 0.5,-0.2 z"
|
||||||
|
id="path9804" />
|
||||||
|
</g>
|
||||||
|
</g><g
|
||||||
|
id="g9824"
|
||||||
|
style="fill:#1a1a1a">
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="m 58.7,13.2 c 4.6,0 7.4,2.5 7.4,6.5 h -4.6 c 0,-1.5 -1.1,-2.4 -2.9,-2.4 -1.9,0 -3.1,0.9 -3.1,2.3 0,1.3 0.7,1.9 2.2,2.2 l 3.2,0.7 c 3.8,0.8 5.6,2.6 5.6,5.9 0,4.1 -3.2,6.8 -8.1,6.8 -4.7,0 -7.8,-2.6 -7.8,-6.5 h 4.6 c 0,1.6 1.1,2.4 3.2,2.4 2.1,0 3.4,-0.8 3.4,-2.2 0,-1.2 -0.5,-1.8 -2,-2.1 l -3.2,-0.7 c -3.8,-0.8 -5.7,-2.9 -5.7,-6.4 0,-3.7 3.2,-6.5 7.8,-6.5 z"
|
||||||
|
id="path9810"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M 70.4,34.9 78,13.6 h 4.5 l 7.6,21.3 h -4.9 l -1.5,-4.5 h -6.9 l -1.5,4.5 z m 7.7,-8.4 h 4.2 L 80.8,22 c -0.2,-0.7 -0.5,-1.6 -0.6,-2.1 -0.1,0.5 -0.3,1.3 -0.6,2.1 z"
|
||||||
|
id="path9812"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M 95.3,34.9 V 13.6 h 4.6 l 9,13.5 V 13.6 h 4.6 v 21.3 h -4.6 l -9,-13.5 v 13.5 z"
|
||||||
|
id="path9814"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M 120.7,34.9 V 13.6 h 8 c 6.2,0 10.6,4.4 10.6,10.7 0,6.2 -4.2,10.6 -10.3,10.6 z m 4.7,-17 v 12.6 h 3.2 c 3.7,0 5.8,-2.3 5.8,-6.3 0,-4 -2.3,-6.4 -6.1,-6.4 h -2.9 z"
|
||||||
|
id="path9816"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="m 145.4,13.6 h 8.8 c 4.3,0 6.9,2.2 6.9,5.9 0,2.3 -1,3.9 -3,4.8 2.1,0.7 3.2,2.3 3.2,4.7 0,3.8 -2.5,5.9 -7.1,5.9 h -8.8 z m 4.7,4.1 v 4.6 h 3.7 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.5 -0.9,-2.3 -2.6,-2.3 h -3.7 z m 0,8.5 v 4.6 h 3.9 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.4 -0.9,-2.2 -2.6,-2.2 z"
|
||||||
|
id="path9818"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="m 176.5,35.2 c -6.1,0 -10.4,-4.5 -10.4,-11 0,-6.5 4.3,-11 10.4,-11 6.2,0 10.4,4.5 10.4,11 0,6.5 -4.2,11 -10.4,11 z m 0.1,-17.5 c -3.4,0 -5.5,2.4 -5.5,6.5 0,4.1 2.1,6.5 5.5,6.5 3.4,0 5.5,-2.5 5.5,-6.5 0,-4 -2.1,-6.5 -5.5,-6.5 z"
|
||||||
|
id="path9820"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="m 190.4,13.6 h 5.5 l 1.8,2.8 c 0.8,1.2 1.5,2.5 2.5,4.3 l 4.3,-7 h 5.4 l -6.7,10.6 6.7,10.6 h -5.5 L 203,32.7 c -1.1,-1.7 -1.8,-3 -2.8,-4.9 l -4.6,7.1 h -5.5 l 7.1,-10.6 z"
|
||||||
|
id="path9822"
|
||||||
|
style="fill:#1a1a1a" />
|
||||||
|
</g><path
|
||||||
|
class="st0"
|
||||||
|
d="m 229,34.9 h 4.7 L 226,13.6 h -4.3 L 214,34.8 h 4.6 l 1.6,-4.5 h 7.1 z m -5.1,-14.6 c 0,0 0,0 0,0 0,-0.1 0,-0.1 0,0 l 2.2,6.2 h -4.4 z"
|
||||||
|
id="path9826" /><g
|
||||||
|
id="g9832">
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="m 259.5,11.2 h 3.9 v 1 h -1.3 v 3.1 h -1.3 v -3.1 h -1.3 z m 4.5,0 h 1.7 l 0.6,2.5 0.6,-2.5 h 1.7 v 4.1 h -1 v -3.1 l -0.8,3.1 h -0.9 l -0.8,-3.1 v 3.1 h -1 v -4.1 z"
|
||||||
|
id="path9830" />
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 5.0 KiB |
133
spqlios/lib/docs/logo-sandboxaq-white.svg
Normal file
133
spqlios/lib/docs/logo-sandboxaq-white.svg
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||||
|
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||||
|
|
||||||
|
<svg
|
||||||
|
version="1.1"
|
||||||
|
id="Layer_1"
|
||||||
|
x="0px"
|
||||||
|
y="0px"
|
||||||
|
viewBox="0 0 270 49.4"
|
||||||
|
style="enable-background:new 0 0 270 49.4;"
|
||||||
|
xml:space="preserve"
|
||||||
|
sodipodi:docname="logo-sandboxaq-white.svg"
|
||||||
|
inkscape:version="1.2.2 (1:1.2.2+202212051551+b0a8486541)"
|
||||||
|
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||||
|
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||||
|
id="defs9839" /><sodipodi:namedview
|
||||||
|
id="namedview9837"
|
||||||
|
pagecolor="#ffffff"
|
||||||
|
bordercolor="#000000"
|
||||||
|
borderopacity="0.25"
|
||||||
|
inkscape:showpageshadow="2"
|
||||||
|
inkscape:pageopacity="0.0"
|
||||||
|
inkscape:pagecheckerboard="0"
|
||||||
|
inkscape:deskcolor="#d1d1d1"
|
||||||
|
showgrid="false"
|
||||||
|
inkscape:zoom="2.3886639"
|
||||||
|
inkscape:cx="135.22204"
|
||||||
|
inkscape:cy="25.327967"
|
||||||
|
inkscape:window-width="1072"
|
||||||
|
inkscape:window-height="688"
|
||||||
|
inkscape:window-x="0"
|
||||||
|
inkscape:window-y="0"
|
||||||
|
inkscape:window-maximized="1"
|
||||||
|
inkscape:current-layer="Layer_1" />
|
||||||
|
<style
|
||||||
|
type="text/css"
|
||||||
|
id="style9786">
|
||||||
|
.st0{fill:#EBB028;}
|
||||||
|
.st1{fill:#FFFFFF;}
|
||||||
|
</style>
|
||||||
|
<text
|
||||||
|
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||||
|
class="st1"
|
||||||
|
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||||
|
id="text9788">SANDBOX </text>
|
||||||
|
<text
|
||||||
|
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||||
|
class="st1"
|
||||||
|
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||||
|
id="text9790">AQ</text>
|
||||||
|
<g
|
||||||
|
id="g9834">
|
||||||
|
<g
|
||||||
|
id="g9828">
|
||||||
|
<g
|
||||||
|
id="g9808">
|
||||||
|
<g
|
||||||
|
id="g9800">
|
||||||
|
<g
|
||||||
|
id="g9798">
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M8.9,9.7v3.9l29.6,17.1v2.7c0,1.2-0.6,2.3-1.6,2.9L31,39.8v-4L1.4,18.6v-2.7c0-1.2,0.6-2.3,1.7-2.9 L8.9,9.7z"
|
||||||
|
id="path9792" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M18.3,45.1L3.1,36.3c-1-0.6-1.7-1.7-1.7-2.9V26L28,41.4l-6.5,3.7C20.6,45.7,19.3,45.7,18.3,45.1z"
|
||||||
|
id="path9794" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M21.6,4.3l15.2,8.8c1,0.6,1.7,1.7,1.7,2.9v7.5L11.8,8l6.5-3.7C19.3,3.7,20.6,3.7,21.6,4.3z"
|
||||||
|
id="path9796" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<g
|
||||||
|
id="g9806">
|
||||||
|
<polygon
|
||||||
|
class="st0"
|
||||||
|
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||||
|
id="polygon9802" />
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M246.9,31l-0.1-0.1l-0.1,0c-0.2,0-0.4,0-0.6,0c-3.5,0-5.7-2.6-5.7-6.7c0-4.1,2.2-6.7,5.7-6.7 s5.7,2.6,5.7,6.7c0,0.3,0,0.6,0,0.9l3.6,4.2c0.7-1.5,1-3.2,1-5.1c0-6.5-4.2-11-10.3-11c-6.1,0-10.3,4.5-10.3,11s4.2,11,10.3,11 c1.2,0,2.3-0.2,3.4-0.5l0.5-0.2L246.9,31z"
|
||||||
|
id="path9804" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<g
|
||||||
|
id="g9824">
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M58.7,13.2c4.6,0,7.4,2.5,7.4,6.5h-4.6c0-1.5-1.1-2.4-2.9-2.4c-1.9,0-3.1,0.9-3.1,2.3c0,1.3,0.7,1.9,2.2,2.2 l3.2,0.7c3.8,0.8,5.6,2.6,5.6,5.9c0,4.1-3.2,6.8-8.1,6.8c-4.7,0-7.8-2.6-7.8-6.5h4.6c0,1.6,1.1,2.4,3.2,2.4 c2.1,0,3.4-0.8,3.4-2.2c0-1.2-0.5-1.8-2-2.1l-3.2-0.7c-3.8-0.8-5.7-2.9-5.7-6.4C50.9,16,54.1,13.2,58.7,13.2z"
|
||||||
|
id="path9810" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M70.4,34.9L78,13.6h4.5l7.6,21.3h-4.9l-1.5-4.5h-6.9l-1.5,4.5H70.4z M78.1,26.5h4.2L80.8,22 c-0.2-0.7-0.5-1.6-0.6-2.1c-0.1,0.5-0.3,1.3-0.6,2.1L78.1,26.5z"
|
||||||
|
id="path9812" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M95.3,34.9V13.6h4.6l9,13.5V13.6h4.6v21.3h-4.6l-9-13.5v13.5H95.3z"
|
||||||
|
id="path9814" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M120.7,34.9V13.6h8c6.2,0,10.6,4.4,10.6,10.7c0,6.2-4.2,10.6-10.3,10.6H120.7z M125.4,17.9v12.6h3.2 c3.7,0,5.8-2.3,5.8-6.3c0-4-2.3-6.4-6.1-6.4H125.4z"
|
||||||
|
id="path9816" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M145.4,13.6h8.8c4.3,0,6.9,2.2,6.9,5.9c0,2.3-1,3.9-3,4.8c2.1,0.7,3.2,2.3,3.2,4.7c0,3.8-2.5,5.9-7.1,5.9 h-8.8V13.6z M150.1,17.7v4.6h3.7c1.7,0,2.6-0.8,2.6-2.4c0-1.5-0.9-2.3-2.6-2.3H150.1z M150.1,26.2v4.6h3.9c1.7,0,2.6-0.8,2.6-2.4 c0-1.4-0.9-2.2-2.6-2.2H150.1z"
|
||||||
|
id="path9818" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M176.5,35.2c-6.1,0-10.4-4.5-10.4-11s4.3-11,10.4-11c6.2,0,10.4,4.5,10.4,11S182.7,35.2,176.5,35.2z M176.6,17.7c-3.4,0-5.5,2.4-5.5,6.5c0,4.1,2.1,6.5,5.5,6.5c3.4,0,5.5-2.5,5.5-6.5C182.1,20.2,180,17.7,176.6,17.7z"
|
||||||
|
id="path9820" />
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M190.4,13.6h5.5l1.8,2.8c0.8,1.2,1.5,2.5,2.5,4.3l4.3-7h5.4l-6.7,10.6l6.7,10.6h-5.5l-1.4-2.2 c-1.1-1.7-1.8-3-2.8-4.9l-4.6,7.1h-5.5l7.1-10.6L190.4,13.6z"
|
||||||
|
id="path9822" />
|
||||||
|
</g>
|
||||||
|
<path
|
||||||
|
class="st0"
|
||||||
|
d="M229,34.9h4.7L226,13.6h-4.3l-7.7,21.2h4.6l1.6-4.5h7.1L229,34.9z M223.9,20.3 C223.9,20.3,223.9,20.3,223.9,20.3C223.9,20.2,223.9,20.2,223.9,20.3l2.2,6.2h-4.4L223.9,20.3z"
|
||||||
|
id="path9826" />
|
||||||
|
</g>
|
||||||
|
<g
|
||||||
|
id="g9832">
|
||||||
|
<path
|
||||||
|
class="st1"
|
||||||
|
d="M259.5,11.2h3.9v1h-1.3v3.1h-1.3v-3.1h-1.3V11.2L259.5,11.2z M264,11.2h1.7l0.6,2.5l0.6-2.5h1.7v4.1h-1v-3.1 l-0.8,3.1h-0.9l-0.8-3.1v3.1h-1V11.2L264,11.2z"
|
||||||
|
id="path9830" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 4.7 KiB |
2
spqlios/lib/manifest.yaml
Normal file
2
spqlios/lib/manifest.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
library: spqlios-fft
|
||||||
|
version: 2.0.0
|
||||||
27
spqlios/lib/scripts/auto-release.sh
Normal file
27
spqlios/lib/scripts/auto-release.sh
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# this script generates one tag if there is a version change in manifest.yaml
|
||||||
|
cd `dirname $0`/..
|
||||||
|
if [ "v$1" = "v-y" ]; then
|
||||||
|
echo "production mode!";
|
||||||
|
fi
|
||||||
|
changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
|
||||||
|
oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
|
||||||
|
version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
|
||||||
|
echo "Versions: $oldversion --> $version"
|
||||||
|
if [ "v$oldversion" = "v$version" ]; then
|
||||||
|
echo "Same version - nothing to do"; exit 0;
|
||||||
|
fi
|
||||||
|
if [ "v$1" = "v-y" ]; then
|
||||||
|
git config user.name github-actions
|
||||||
|
git config user.email github-actions@github.com
|
||||||
|
git tag -a "v$version" -m "Version $version"
|
||||||
|
git push origin "v$version"
|
||||||
|
else
|
||||||
|
cat <<EOF
|
||||||
|
# the script would do:
|
||||||
|
git tag -a "v$version" -m "Version $version"
|
||||||
|
git push origin "v$version"
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
102
spqlios/lib/scripts/ci-pkg
Normal file
102
spqlios/lib/scripts/ci-pkg
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# ONLY USE A PREFIX YOU ARE CONFIDENT YOU CAN WIPE OUT ENTIRELY
|
||||||
|
CI_INSTALL_PREFIX=/opt/spqlios
|
||||||
|
CI_REPO_URL=https://spq-dav.algonics.net/ci
|
||||||
|
WORKDIR=`pwd`
|
||||||
|
if [ "x$DESTDIR" = "x" ]; then
|
||||||
|
DESTDIR=/
|
||||||
|
else
|
||||||
|
mkdir -p $DESTDIR
|
||||||
|
DESTDIR=`realpath $DESTDIR`
|
||||||
|
fi
|
||||||
|
DIR=`dirname "$0"`
|
||||||
|
cd $DIR/..
|
||||||
|
DIR=`pwd`
|
||||||
|
|
||||||
|
FULL_UNAME=`uname -a | tr '[A-Z]' '[a-z]'`
|
||||||
|
HOST=`echo $FULL_UNAME | sed 's/ .*//'`
|
||||||
|
ARCH=none
|
||||||
|
case "$HOST" in
|
||||||
|
*linux*)
|
||||||
|
DISTRIB=`lsb_release -c | awk '{print $2}' | tr '[A-Z]' '[a-z]'`
|
||||||
|
HOST=linux-$DISTRIB
|
||||||
|
;;
|
||||||
|
*darwin*)
|
||||||
|
HOST=darwin
|
||||||
|
;;
|
||||||
|
*mingw*|*msys*)
|
||||||
|
DISTRIB=`echo $MSYSTEM | tr '[A-Z]' '[a-z]'`
|
||||||
|
HOST=msys64-$DISTRIB
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Host unknown: $HOST";
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
case "$FULL_UNAME" in
|
||||||
|
*x86_64*)
|
||||||
|
ARCH=x86_64
|
||||||
|
;;
|
||||||
|
*aarch64*)
|
||||||
|
ARCH=aarch64
|
||||||
|
;;
|
||||||
|
*arm64*)
|
||||||
|
ARCH=arm64
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Architecture unknown: $FULL_UNAME";
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
UNAME="$HOST-$ARCH"
|
||||||
|
CMH=
|
||||||
|
if [ -d lib/spqlios/.git ]; then
|
||||||
|
CMH=`git submodule status | sed 's/\(..........\).*/\1/'`
|
||||||
|
else
|
||||||
|
CMH=`git rev-parse HEAD | sed 's/\(..........\).*/\1/'`
|
||||||
|
fi
|
||||||
|
FNAME=spqlios-arithmetic-$CMH-$UNAME.tar.gz
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
================= CI MINI-PACKAGER ==================
|
||||||
|
Work Dir: WORKDIR=$WORKDIR
|
||||||
|
Spq Dir: DIR=$DIR
|
||||||
|
Install Root: DESTDIR=$DESTDIR
|
||||||
|
Install Prefix: CI_INSTALL_PREFIX=$CI_INSTALL_PREFIX
|
||||||
|
Archive Name: FNAME=$FNAME
|
||||||
|
CI WebDav: CI_REPO_URL=$CI_REPO_URL
|
||||||
|
=====================================================
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [ "x$1" = "xcreate" ]; then
|
||||||
|
rm -rf dist
|
||||||
|
cmake -B build -S . -DCMAKE_INSTALL_PREFIX="$CI_INSTALL_PREFIX" -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON || exit 1
|
||||||
|
cmake --build build || exit 1
|
||||||
|
rm -rf "$DIR/dist" 2>/dev/null
|
||||||
|
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||||
|
DESTDIR="$DIR/dist" cmake --install build || exit 1
|
||||||
|
if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
|
||||||
|
tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
|
||||||
|
else
|
||||||
|
# fix since msys can mess up the paths
|
||||||
|
REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
|
||||||
|
echo "REAL_DEST: $REAL_DEST"
|
||||||
|
[ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
|
||||||
|
fi
|
||||||
|
[ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
|
||||||
|
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
|
||||||
|
curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "x$1" = "xinstall" ]; then
|
||||||
|
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
|
||||||
|
# cleaning
|
||||||
|
rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
|
||||||
|
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||||
|
# downloading
|
||||||
|
curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||||
|
[ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
|
||||||
|
# installing
|
||||||
|
mkdir -p $DESTDIR
|
||||||
|
tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
181
spqlios/lib/scripts/prepare-release
Normal file
181
spqlios/lib/scripts/prepare-release
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
#!/usr/bin/perl
|
||||||
|
##
|
||||||
|
## This script will help update manifest.yaml and Changelog.md before a release
|
||||||
|
## Any merge to master that changes the version line in manifest.yaml
|
||||||
|
## is considered as a new release.
|
||||||
|
##
|
||||||
|
## When ready to make a release, please run ./scripts/prepare-release
|
||||||
|
## and commit push the final result!
|
||||||
|
use File::Basename;
|
||||||
|
use Cwd 'abs_path';
|
||||||
|
|
||||||
|
# find its way to the root of git's repository
|
||||||
|
my $scriptsdirname = dirname(abs_path(__FILE__));
|
||||||
|
chdir "$scriptsdirname/..";
|
||||||
|
print "✓ Entering directory:".`pwd`;
|
||||||
|
|
||||||
|
# ensures that the current branch is ahead of origin/main
|
||||||
|
my $diff= `git diff`;
|
||||||
|
chop $diff;
|
||||||
|
if ($diff =~ /./) {
|
||||||
|
die("ERROR: Please commit all the changes before calling the prepare-release script.");
|
||||||
|
} else {
|
||||||
|
print("✓ All changes are comitted.\n");
|
||||||
|
}
|
||||||
|
system("git fetch origin");
|
||||||
|
my $vcount = `git rev-list --left-right --count origin/main...HEAD`;
|
||||||
|
$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/;
|
||||||
|
if ($2>0) {
|
||||||
|
die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main.");
|
||||||
|
} else {
|
||||||
|
print("✓ Current HEAD is up to date with origin/main.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
mkdir ".changes";
|
||||||
|
my $currentbranch = `git rev-parse --abbrev-ref HEAD`;
|
||||||
|
chop $currentbranch;
|
||||||
|
$currentbranch =~ s/[^a-zA-Z._-]+/-/g;
|
||||||
|
my $changefile=".changes/$currentbranch.md";
|
||||||
|
my $origmanifestfile=".changes/$currentbranch--manifest.yaml";
|
||||||
|
my $origchangelogfile=".changes/$currentbranch--Changelog.md";
|
||||||
|
|
||||||
|
my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml");
|
||||||
|
if ($exit_code!=0 or ! -f $origmanifestfile) {
|
||||||
|
die("ERROR: failed to download manifest.yaml");
|
||||||
|
}
|
||||||
|
$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md");
|
||||||
|
if ($exit_code!=0 or ! -f $origchangelogfile) {
|
||||||
|
die("ERROR: failed to download Changelog.md");
|
||||||
|
}
|
||||||
|
|
||||||
|
# read the current version (from origin/main manifest)
|
||||||
|
my $vmajor = 0;
|
||||||
|
my $vminor = 0;
|
||||||
|
my $vpatch = 0;
|
||||||
|
my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`;
|
||||||
|
chop $versionline;
|
||||||
|
if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) {
|
||||||
|
die("ERROR: invalid version in manifest file: $versionline\n");
|
||||||
|
} else {
|
||||||
|
$vmajor = int($1);
|
||||||
|
$vminor = int($2);
|
||||||
|
$vpatch = int($3);
|
||||||
|
}
|
||||||
|
print "Version in manifest file: $vmajor.$vminor.$vpatch\n";
|
||||||
|
|
||||||
|
if (not -f $changefile) {
|
||||||
|
## create a changes file
|
||||||
|
open F,">$changefile";
|
||||||
|
print F "# Changefile for branch $currentbranch\n\n";
|
||||||
|
print F "## Type of release (major,minor,patch)?\n\n";
|
||||||
|
print F "releasetype: patch\n\n";
|
||||||
|
print F "## What has changed (please edit)?\n\n";
|
||||||
|
print F "- This has changed.\n";
|
||||||
|
close F;
|
||||||
|
}
|
||||||
|
|
||||||
|
system("editor $changefile");
|
||||||
|
|
||||||
|
# compute the new version
|
||||||
|
my $nvmajor;
|
||||||
|
my $nvminor;
|
||||||
|
my $nvpatch;
|
||||||
|
my $changelog;
|
||||||
|
my $recordchangelog=0;
|
||||||
|
open F,"$changefile";
|
||||||
|
while ($line=<F>) {
|
||||||
|
chop $line;
|
||||||
|
if ($recordchangelog) {
|
||||||
|
($line =~ /^$/) and next;
|
||||||
|
$changelog .= "$line\n";
|
||||||
|
next;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *patch *$/) {
|
||||||
|
$nvmajor=$vmajor;
|
||||||
|
$nvminor=$vminor;
|
||||||
|
$nvpatch=$vpatch+1;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *minor *$/) {
|
||||||
|
$nvmajor=$vmajor;
|
||||||
|
$nvminor=$vminor+1;
|
||||||
|
$nvpatch=0;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *major *$/) {
|
||||||
|
$nvmajor=$vmajor+1;
|
||||||
|
$nvminor=0;
|
||||||
|
$nvpatch=0;
|
||||||
|
}
|
||||||
|
if ($line =~ /^## What has changed/) {
|
||||||
|
$recordchangelog=1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
print "New version: $nvmajor.$nvminor.$nvpatch\n";
|
||||||
|
print "Changes:\n$changelog";
|
||||||
|
|
||||||
|
# updating manifest.yaml
|
||||||
|
open F,"manifest.yaml";
|
||||||
|
open G,">.changes/manifest.yaml";
|
||||||
|
while ($line=<F>) {
|
||||||
|
if ($line =~ /^version *: */) {
|
||||||
|
print G "version: $nvmajor.$nvminor.$nvpatch\n";
|
||||||
|
next;
|
||||||
|
}
|
||||||
|
print G $line;
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
close G;
|
||||||
|
# updating Changelog.md
|
||||||
|
open F,"$origchangelogfile";
|
||||||
|
open G,">.changes/Changelog.md";
|
||||||
|
print G <<EOF
|
||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
EOF
|
||||||
|
;
|
||||||
|
print G "## [$nvmajor.$nvminor.$nvpatch] - ".`date '+%Y-%m-%d'`."\n";
|
||||||
|
print G "$changelog\n";
|
||||||
|
my $skip_section=1;
|
||||||
|
while ($line=<F>) {
|
||||||
|
if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) {
|
||||||
|
if ($1>$nvmajor) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($1<$nvmajor) {
|
||||||
|
$skip_section=0;
|
||||||
|
} elsif ($2>$nvminor) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($2<$nvminor) {
|
||||||
|
$skip_section=0;
|
||||||
|
} elsif ($3>$nvpatch) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($2<$nvpatch) {
|
||||||
|
$skip_section=0;
|
||||||
|
} else {
|
||||||
|
$skip_section=1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
($skip_section) and next;
|
||||||
|
print G $line;
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
close G;
|
||||||
|
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
print "THIS WILL BE UPDATED:\n";
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
system("diff -u manifest.yaml .changes/manifest.yaml");
|
||||||
|
system("diff -u Changelog.md .changes/Changelog.md");
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
print "To proceed: press <enter> otherwise <CTRL+C>\n";
|
||||||
|
my $bla;
|
||||||
|
$bla=<STDIN>;
|
||||||
|
system("cp -vf .changes/manifest.yaml manifest.yaml");
|
||||||
|
system("cp -vf .changes/Changelog.md Changelog.md");
|
||||||
|
system("git commit -a -m \"Update version and changelog.\"");
|
||||||
|
system("git push");
|
||||||
|
print("✓ Changes have been committed and pushed!\n");
|
||||||
|
print("✓ A new release will be created when this branch is merged to main.\n");
|
||||||
|
|
||||||
223
spqlios/lib/spqlios/CMakeLists.txt
Normal file
223
spqlios/lib/spqlios/CMakeLists.txt
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
enable_language(ASM)
|
||||||
|
|
||||||
|
# C source files that are compiled for all targets (i.e. reference code)
|
||||||
|
set(SRCS_GENERIC
|
||||||
|
commons.c
|
||||||
|
commons_private.c
|
||||||
|
coeffs/coeffs_arithmetic.c
|
||||||
|
arithmetic/vec_znx.c
|
||||||
|
arithmetic/vec_znx_dft.c
|
||||||
|
arithmetic/vector_matrix_product.c
|
||||||
|
cplx/cplx_common.c
|
||||||
|
cplx/cplx_conversions.c
|
||||||
|
cplx/cplx_fft_asserts.c
|
||||||
|
cplx/cplx_fft_ref.c
|
||||||
|
cplx/cplx_fftvec_ref.c
|
||||||
|
cplx/cplx_ifft_ref.c
|
||||||
|
cplx/spqlios_cplx_fft.c
|
||||||
|
reim4/reim4_arithmetic_ref.c
|
||||||
|
reim4/reim4_fftvec_addmul_ref.c
|
||||||
|
reim4/reim4_fftvec_conv_ref.c
|
||||||
|
reim/reim_conversions.c
|
||||||
|
reim/reim_fft_ifft.c
|
||||||
|
reim/reim_fft_ref.c
|
||||||
|
reim/reim_fftvec_addmul_ref.c
|
||||||
|
reim/reim_ifft_ref.c
|
||||||
|
reim/reim_ifft_ref.c
|
||||||
|
reim/reim_to_tnx_ref.c
|
||||||
|
q120/q120_ntt.c
|
||||||
|
q120/q120_arithmetic_ref.c
|
||||||
|
q120/q120_arithmetic_simple.c
|
||||||
|
arithmetic/scalar_vector_product.c
|
||||||
|
arithmetic/vec_znx_big.c
|
||||||
|
arithmetic/znx_small.c
|
||||||
|
arithmetic/module_api.c
|
||||||
|
arithmetic/zn_vmp_int8_ref.c
|
||||||
|
arithmetic/zn_vmp_int16_ref.c
|
||||||
|
arithmetic/zn_vmp_int32_ref.c
|
||||||
|
arithmetic/zn_vmp_ref.c
|
||||||
|
arithmetic/zn_api.c
|
||||||
|
arithmetic/zn_conversions_ref.c
|
||||||
|
arithmetic/zn_approxdecomp_ref.c
|
||||||
|
arithmetic/vec_rnx_api.c
|
||||||
|
arithmetic/vec_rnx_conversions_ref.c
|
||||||
|
arithmetic/vec_rnx_svp_ref.c
|
||||||
|
reim/reim_execute.c
|
||||||
|
cplx/cplx_execute.c
|
||||||
|
reim4/reim4_execute.c
|
||||||
|
arithmetic/vec_rnx_arithmetic.c
|
||||||
|
arithmetic/vec_rnx_approxdecomp_ref.c
|
||||||
|
arithmetic/vec_rnx_vmp_ref.c
|
||||||
|
)
|
||||||
|
# C or assembly source files compiled only on x86 targets
|
||||||
|
set(SRCS_X86
|
||||||
|
)
|
||||||
|
# C or assembly source files compiled only on aarch64 targets
|
||||||
|
set(SRCS_AARCH64
|
||||||
|
cplx/cplx_fallbacks_aarch64.c
|
||||||
|
reim/reim_fallbacks_aarch64.c
|
||||||
|
reim4/reim4_fallbacks_aarch64.c
|
||||||
|
q120/q120_fallbacks_aarch64.c
|
||||||
|
reim/reim_fft_neon.c
|
||||||
|
)
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx, avx2, fma targets
|
||||||
|
set(SRCS_FMA_C
|
||||||
|
arithmetic/vector_matrix_product_avx.c
|
||||||
|
cplx/cplx_conversions_avx2_fma.c
|
||||||
|
cplx/cplx_fft_avx2_fma.c
|
||||||
|
cplx/cplx_fft_sse.c
|
||||||
|
cplx/cplx_fftvec_avx2_fma.c
|
||||||
|
cplx/cplx_ifft_avx2_fma.c
|
||||||
|
reim4/reim4_arithmetic_avx2.c
|
||||||
|
reim4/reim4_fftvec_conv_fma.c
|
||||||
|
reim4/reim4_fftvec_addmul_fma.c
|
||||||
|
reim/reim_conversions_avx.c
|
||||||
|
reim/reim_fft4_avx_fma.c
|
||||||
|
reim/reim_fft8_avx_fma.c
|
||||||
|
reim/reim_ifft4_avx_fma.c
|
||||||
|
reim/reim_ifft8_avx_fma.c
|
||||||
|
reim/reim_fft_avx2.c
|
||||||
|
reim/reim_ifft_avx2.c
|
||||||
|
reim/reim_to_tnx_avx.c
|
||||||
|
reim/reim_fftvec_addmul_fma.c
|
||||||
|
)
|
||||||
|
set(SRCS_FMA_ASM
|
||||||
|
cplx/cplx_fft16_avx_fma.s
|
||||||
|
cplx/cplx_ifft16_avx_fma.s
|
||||||
|
reim/reim_fft16_avx_fma.s
|
||||||
|
reim/reim_ifft16_avx_fma.s
|
||||||
|
)
|
||||||
|
set(SRCS_FMA_WIN32_ASM
|
||||||
|
cplx/cplx_fft16_avx_fma_win32.s
|
||||||
|
cplx/cplx_ifft16_avx_fma_win32.s
|
||||||
|
reim/reim_fft16_avx_fma_win32.s
|
||||||
|
reim/reim_ifft16_avx_fma_win32.s
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||||
|
set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets
|
||||||
|
set(SRCS_AVX512
|
||||||
|
cplx/cplx_fft_avx512.c
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq")
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx2 + bmi targets
|
||||||
|
set(SRCS_AVX2
|
||||||
|
arithmetic/vec_znx_avx.c
|
||||||
|
coeffs/coeffs_arithmetic_avx.c
|
||||||
|
arithmetic/vec_znx_dft_avx2.c
|
||||||
|
arithmetic/zn_vmp_int8_avx.c
|
||||||
|
arithmetic/zn_vmp_int16_avx.c
|
||||||
|
arithmetic/zn_vmp_int32_avx.c
|
||||||
|
q120/q120_arithmetic_avx2.c
|
||||||
|
q120/q120_ntt_avx2.c
|
||||||
|
arithmetic/vec_rnx_arithmetic_avx.c
|
||||||
|
arithmetic/vec_rnx_approxdecomp_avx.c
|
||||||
|
arithmetic/vec_rnx_vmp_avx.c
|
||||||
|
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")
|
||||||
|
|
||||||
|
# C source files on float128 via libquadmath on x86 targets targets
|
||||||
|
set(SRCS_F128
|
||||||
|
cplx_f128/cplx_fft_f128.c
|
||||||
|
cplx_f128/cplx_fft_f128.h
|
||||||
|
)
|
||||||
|
|
||||||
|
# H header files containing the public API (these headers are installed)
|
||||||
|
set(HEADERSPUBLIC
|
||||||
|
commons.h
|
||||||
|
arithmetic/vec_znx_arithmetic.h
|
||||||
|
arithmetic/vec_rnx_arithmetic.h
|
||||||
|
arithmetic/zn_arithmetic.h
|
||||||
|
cplx/cplx_fft.h
|
||||||
|
reim/reim_fft.h
|
||||||
|
q120/q120_common.h
|
||||||
|
q120/q120_arithmetic.h
|
||||||
|
q120/q120_ntt.h
|
||||||
|
)
|
||||||
|
|
||||||
|
# H header files containing the private API (these headers are used internally)
|
||||||
|
set(HEADERSPRIVATE
|
||||||
|
commons_private.h
|
||||||
|
cplx/cplx_fft_internal.h
|
||||||
|
cplx/cplx_fft_private.h
|
||||||
|
reim4/reim4_arithmetic.h
|
||||||
|
reim4/reim4_fftvec_internal.h
|
||||||
|
reim4/reim4_fftvec_private.h
|
||||||
|
reim4/reim4_fftvec_public.h
|
||||||
|
reim/reim_fft_internal.h
|
||||||
|
reim/reim_fft_private.h
|
||||||
|
q120/q120_arithmetic_private.h
|
||||||
|
q120/q120_ntt_private.h
|
||||||
|
arithmetic/vec_znx_arithmetic.h
|
||||||
|
arithmetic/vec_rnx_arithmetic_private.h
|
||||||
|
arithmetic/vec_rnx_arithmetic_plugin.h
|
||||||
|
arithmetic/zn_arithmetic_private.h
|
||||||
|
arithmetic/zn_arithmetic_plugin.h
|
||||||
|
coeffs/coeffs_arithmetic.h
|
||||||
|
reim/reim_fft_core_template.h
|
||||||
|
)
|
||||||
|
|
||||||
|
set(SPQLIOSSOURCES
|
||||||
|
${SRCS_GENERIC}
|
||||||
|
${HEADERSPUBLIC}
|
||||||
|
${HEADERSPRIVATE}
|
||||||
|
)
|
||||||
|
if (${X86})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
${SRCS_X86}
|
||||||
|
${SRCS_FMA_C}
|
||||||
|
${SRCS_FMA_ASM}
|
||||||
|
${SRCS_AVX2}
|
||||||
|
${SRCS_AVX512}
|
||||||
|
)
|
||||||
|
elseif (${X86_WIN32})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
#${SRCS_X86}
|
||||||
|
${SRCS_FMA_C}
|
||||||
|
${SRCS_FMA_WIN32_ASM}
|
||||||
|
${SRCS_AVX2}
|
||||||
|
${SRCS_AVX512}
|
||||||
|
)
|
||||||
|
elseif (${AARCH64})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
${SRCS_AARCH64}
|
||||||
|
)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
|
||||||
|
set(SPQLIOSLIBDEP
|
||||||
|
m # libmath depencency for cosinus/sinus functions
|
||||||
|
)
|
||||||
|
|
||||||
|
if (ENABLE_SPQLIOS_F128)
|
||||||
|
find_library(quadmath REQUIRED NAMES quadmath)
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128})
|
||||||
|
set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath)
|
||||||
|
endif (ENABLE_SPQLIOS_F128)
|
||||||
|
|
||||||
|
add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
|
||||||
|
add_library(libspqlios SHARED ${SPQLIOSSOURCES})
|
||||||
|
set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)
|
||||||
|
set_property(TARGET libspqlios-static PROPERTY OUTPUT_NAME spqlios)
|
||||||
|
set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR})
|
||||||
|
set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION})
|
||||||
|
if (NOT APPLE)
|
||||||
|
target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined)
|
||||||
|
target_link_options(libspqlios PUBLIC -Wl,--no-undefined)
|
||||||
|
endif()
|
||||||
|
target_link_libraries(libspqlios ${SPQLIOSLIBDEP})
|
||||||
|
target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP})
|
||||||
|
install(TARGETS libspqlios-static)
|
||||||
|
install(TARGETS libspqlios)
|
||||||
|
|
||||||
|
# install the public headers only
|
||||||
|
foreach (file ${HEADERSPUBLIC})
|
||||||
|
get_filename_component(dir ${file} DIRECTORY)
|
||||||
|
install(FILES ${file} DESTINATION include/spqlios/${dir})
|
||||||
|
endforeach ()
|
||||||
164
spqlios/lib/spqlios/arithmetic/module_api.c
Normal file
164
spqlios/lib/spqlios/arithmetic/module_api.c
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
static void fill_generic_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
module->func.vec_znx_zero = vec_znx_zero_ref;
|
||||||
|
module->func.vec_znx_copy = vec_znx_copy_ref;
|
||||||
|
module->func.vec_znx_negate = vec_znx_negate_ref;
|
||||||
|
module->func.vec_znx_add = vec_znx_add_ref;
|
||||||
|
module->func.vec_znx_sub = vec_znx_sub_ref;
|
||||||
|
module->func.vec_znx_rotate = vec_znx_rotate_ref;
|
||||||
|
module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
|
||||||
|
module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
|
||||||
|
module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
module->func.vec_znx_negate = vec_znx_negate_avx;
|
||||||
|
module->func.vec_znx_add = vec_znx_add_avx;
|
||||||
|
module->func.vec_znx_sub = vec_znx_sub_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_fft64_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
// module->func.vec_znx_dft = ...;
|
||||||
|
module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
|
||||||
|
module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes;
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k;
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||||
|
module->func.vec_znx_dft = fft64_vec_znx_dft;
|
||||||
|
module->func.vec_znx_idft = fft64_vec_znx_idft;
|
||||||
|
module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes;
|
||||||
|
module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a;
|
||||||
|
module->func.vec_znx_big_add = fft64_vec_znx_big_add;
|
||||||
|
module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small;
|
||||||
|
module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2;
|
||||||
|
module->func.vec_znx_big_sub = fft64_vec_znx_big_sub;
|
||||||
|
module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a;
|
||||||
|
module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b;
|
||||||
|
module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2;
|
||||||
|
module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate;
|
||||||
|
module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism;
|
||||||
|
module->func.svp_prepare = fft64_svp_prepare_ref;
|
||||||
|
module->func.svp_apply_dft = fft64_svp_apply_dft_ref;
|
||||||
|
module->func.znx_small_single_product = fft64_znx_small_single_product;
|
||||||
|
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
|
||||||
|
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
|
||||||
|
module->func.vmp_prepare_contiguous_tmp_bytes = fft64_vmp_prepare_contiguous_tmp_bytes;
|
||||||
|
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
|
||||||
|
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
|
||||||
|
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
|
||||||
|
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
|
||||||
|
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||||
|
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||||
|
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||||
|
module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big;
|
||||||
|
module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol;
|
||||||
|
module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
// TODO: enable when avx implementation is done
|
||||||
|
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
|
||||||
|
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
|
||||||
|
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_ntt120_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
// module->func.vec_znx_dft = ...;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
module->func.vec_znx_dft = ntt120_vec_znx_dft_avx;
|
||||||
|
module->func.vec_znx_idft = ntt120_vec_znx_idft_avx;
|
||||||
|
module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx;
|
||||||
|
module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_virtual_table(MODULE* module) {
|
||||||
|
fill_generic_virtual_table(module);
|
||||||
|
switch (module->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
fill_fft64_virtual_table(module);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
fill_ntt120_virtual_table(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // invalid type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_fft64_precomp(MODULE* module) {
|
||||||
|
// fill any necessary precomp stuff
|
||||||
|
module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50);
|
||||||
|
module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
|
||||||
|
module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63);
|
||||||
|
module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
|
||||||
|
module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m);
|
||||||
|
module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m);
|
||||||
|
}
|
||||||
|
static void fill_ntt120_precomp(MODULE* module) {
|
||||||
|
// fill any necessary precomp stuff
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn);
|
||||||
|
module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_module_precomp(MODULE* module) {
|
||||||
|
switch (module->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
fill_fft64_precomp(module);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
fill_ntt120_precomp(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // invalid type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) {
|
||||||
|
// init to zero to ensure that any non-initialized field bug is detected
|
||||||
|
// by at least a "proper" segfault
|
||||||
|
memset(module, 0, sizeof(MODULE));
|
||||||
|
module->module_type = mtype;
|
||||||
|
module->nn = nn;
|
||||||
|
module->m = nn >> 1;
|
||||||
|
fill_module_precomp(module);
|
||||||
|
fill_virtual_table(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) {
|
||||||
|
MODULE* m = (MODULE*)malloc(sizeof(MODULE));
|
||||||
|
fill_module(m, N, mtype);
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_module_info(MODULE* mod) {
|
||||||
|
switch (mod->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
free(mod->mod.fft64.p_conv);
|
||||||
|
free(mod->mod.fft64.p_fft);
|
||||||
|
free(mod->mod.fft64.p_ifft);
|
||||||
|
free(mod->mod.fft64.p_reim_to_znx);
|
||||||
|
free(mod->mod.fft64.mul_fft);
|
||||||
|
free(mod->mod.fft64.p_addmul);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt);
|
||||||
|
q120_del_intt_bb_precomp(mod->mod.q120.p_intt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
free(mod);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; }
|
||||||
63
spqlios/lib/spqlios/arithmetic/scalar_vector_product.c
Normal file
63
spqlios/lib/spqlios/arithmetic/scalar_vector_product.c
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); }
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); }
|
||||||
|
|
||||||
|
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); }
|
||||||
|
|
||||||
|
EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||||
|
|
||||||
|
// public wrappers
|
||||||
|
EXPORT void svp_prepare(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
) {
|
||||||
|
module->func.svp_prepare(module, ppol, pol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, ppol, pol);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, (double*)ppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl) {
|
||||||
|
module->func.svp_apply_dft(module, // N
|
||||||
|
res,
|
||||||
|
res_size, // output
|
||||||
|
ppol, // prepared pol
|
||||||
|
a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// result = ppol * a
|
||||||
|
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
double* const dres = (double*)res;
|
||||||
|
double* const dppol = (double*)ppol;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
double* const res_ptr = dres + i * nn;
|
||||||
|
// copy the polynomial to res, apply fft in place, call fftvec_mul in place.
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, res_ptr);
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
// then extend with zeros
|
||||||
|
memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double));
|
||||||
|
}
|
||||||
318
spqlios/lib/spqlios/arithmetic/vec_rnx_api.c
Normal file
318
spqlios/lib/spqlios/arithmetic/vec_rnx_api.c
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
|
||||||
|
// Add here initialization of items that are in the precomp
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0);
|
||||||
|
module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0);
|
||||||
|
module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m);
|
||||||
|
module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
|
||||||
|
// Add here deleters for items that are in the precomp
|
||||||
|
delete_reim_fft_precomp(module->precomp.fft64.p_fft);
|
||||||
|
delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
|
||||||
|
delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
|
||||||
|
delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
|
||||||
|
// Add function pointers here
|
||||||
|
module->vtable.vec_rnx_add = vec_rnx_add_ref;
|
||||||
|
module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
|
||||||
|
module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
|
||||||
|
module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
|
||||||
|
module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
|
||||||
|
module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
|
||||||
|
module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
|
||||||
|
module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
|
||||||
|
module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
|
||||||
|
module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
|
||||||
|
module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
|
||||||
|
module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
|
||||||
|
module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
|
||||||
|
module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
|
||||||
|
module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
|
||||||
|
module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
|
||||||
|
module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;
|
||||||
|
|
||||||
|
// Add optimized function pointers here
|
||||||
|
if (CPU_SUPPORTS("avx")) {
|
||||||
|
module->vtable.vec_rnx_add = vec_rnx_add_avx;
|
||||||
|
module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
|
||||||
|
module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_rnx_module_info(MOD_RNX* module, //
|
||||||
|
uint64_t n, RNX_MODULE_TYPE mtype) {
|
||||||
|
memset(module, 0, sizeof(MOD_RNX));
|
||||||
|
module->n = n;
|
||||||
|
module->m = n >> 1;
|
||||||
|
module->mtype = mtype;
|
||||||
|
switch (mtype) {
|
||||||
|
case FFT64:
|
||||||
|
fft64_init_rnx_module_precomp(module);
|
||||||
|
fft64_init_rnx_module_vtable(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize_rnx_module_info(MOD_RNX* module) {
|
||||||
|
if (module->custom) module->custom_deleter(module->custom);
|
||||||
|
switch (module->mtype) {
|
||||||
|
case FFT64:
|
||||||
|
fft64_finalize_rnx_module_precomp(module);
|
||||||
|
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
|
||||||
|
MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX));
|
||||||
|
init_rnx_module_info(res, nn, mtype);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
|
||||||
|
finalize_rnx_module_info(module_info);
|
||||||
|
free(module_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; }
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||||
|
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols));
|
||||||
|
}
|
||||||
|
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||||
|
|
||||||
|
//////////////// wrappers //////////////////
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) {
|
||||||
|
return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res size
|
||||||
|
uint64_t a_size, // a size
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||||
|
) {
|
||||||
|
return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }
|
||||||
|
|
||||||
|
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_svp_prepare(module, ppol, pol);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_svp_apply( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_svp_apply(module, // N
|
||||||
|
res, res_size, res_sl, // output
|
||||||
|
ppol, // prepared pol
|
||||||
|
a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a) { // a
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
59
spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c
Normal file
59
spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "immintrin.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a);
|
||||||
|
const uint64_t ell = gadget->ell;
|
||||||
|
const __m256i k = _mm256_set1_epi64x(gadget->k);
|
||||||
|
const __m256d add_cst = _mm256_set1_pd(gadget->add_cst);
|
||||||
|
const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask);
|
||||||
|
const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask);
|
||||||
|
const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst);
|
||||||
|
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||||
|
// gadget decompose column by column
|
||||||
|
if (msize == ell) {
|
||||||
|
// this is the main scenario when msize == ell
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; j += 4) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
const double* aa = a + j;
|
||||||
|
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||||
|
__m256i t_int = _mm256_castpd_si256(t_dbl);
|
||||||
|
do {
|
||||||
|
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||||
|
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||||
|
t_int = _mm256_srlv_epi64(t_int, k);
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
} else if (msize > 0) {
|
||||||
|
// otherwise, if msize < ell: there is one additional rshift
|
||||||
|
const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k);
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; j += 4) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
const double* aa = a + j;
|
||||||
|
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||||
|
__m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh);
|
||||||
|
do {
|
||||||
|
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||||
|
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||||
|
t_int = _mm256_srlv_epi64(t_int, k);
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the last slices (if any)
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
75
spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c
Normal file
75
spqlios/lib/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
typedef union di {
|
||||||
|
double dv;
|
||||||
|
uint64_t uv;
|
||||||
|
} di_t;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||||
|
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
) {
|
||||||
|
if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision");
|
||||||
|
TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET));
|
||||||
|
res->k = k;
|
||||||
|
res->ell = ell;
|
||||||
|
// double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||||
|
union di add_cst;
|
||||||
|
add_cst.dv = UINT64_C(3) << (51 - ell * k);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1);
|
||||||
|
}
|
||||||
|
res->add_cst = add_cst.dv;
|
||||||
|
// uint64_t and_mask; // uint64(2^(K)-1)
|
||||||
|
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||||
|
// uint64_t or_mask; // double(2^52)
|
||||||
|
union di or_mask;
|
||||||
|
or_mask.dv = (UINT64_C(1) << 52);
|
||||||
|
res->or_mask = or_mask.uv;
|
||||||
|
// double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1)));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); }
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t k = gadget->k;
|
||||||
|
const uint64_t ell = gadget->ell;
|
||||||
|
const double add_cst = gadget->add_cst;
|
||||||
|
const uint64_t and_mask = gadget->and_mask;
|
||||||
|
const uint64_t or_mask = gadget->or_mask;
|
||||||
|
const double sub_cst = gadget->sub_cst;
|
||||||
|
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||||
|
const uint64_t first_rsh = (ell - msize) * k;
|
||||||
|
// gadget decompose column by column
|
||||||
|
if (msize > 0) {
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; ++j) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
di_t t = {.dv = a[j] + add_cst};
|
||||||
|
if (msize < ell) t.uv >>= first_rsh;
|
||||||
|
do {
|
||||||
|
di_t u;
|
||||||
|
u.uv = (t.uv & and_mask) | or_mask;
|
||||||
|
*rr = u.dv - sub_cst;
|
||||||
|
t.uv >>= k;
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the last slices (if any)
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
223
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c
Normal file
223
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.c
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] + b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_negate_ref(uint64_t nn, double* res, const double* a) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = -a[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
for (uint64_t i = 0; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
rnx_negate_ref(nn, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_rotate_inplace_f64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_rotate_f64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_automorphism_inplace_f64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_automorphism_f64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . (X^p - 1) */
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_mul_xp_minus_one_inplace(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_mul_xp_minus_one(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
340
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h
Normal file
340
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic.h
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We support the following module families:
|
||||||
|
* - FFT64:
|
||||||
|
* the overall precision should fit at all times over 52 bits.
|
||||||
|
*/
|
||||||
|
typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
typedef struct rnx_module_info_t MOD_RNX;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief obtain a module info for ring dimension N
|
||||||
|
* the module-info knows about:
|
||||||
|
* - the dimension N (or the complex dimension m=N/2)
|
||||||
|
* - any moduleuted fft or ntt items
|
||||||
|
* - the hardware (avx, arm64, x86, ...)
|
||||||
|
*/
|
||||||
|
EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode);
|
||||||
|
EXPORT void delete_rnx_module_info(MOD_RNX* module_info);
|
||||||
|
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module);
|
||||||
|
|
||||||
|
// basic arithmetic
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . (X^p - 1) */
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// conversions //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// isolated products (n.log(n), but not particularly optimized //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** @brief res = a * b : small polynomial product */
|
||||||
|
EXPORT void rnx_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, // output
|
||||||
|
const double* a, // a
|
||||||
|
const double* b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b centermod 1: small polynomial product */
|
||||||
|
EXPORT void tnxdbl_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* torus_res, // output
|
||||||
|
const double* int_a, // a
|
||||||
|
const double* torus_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b: small polynomial product */
|
||||||
|
EXPORT void znx32_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* int_res, // output
|
||||||
|
const int32_t* int_a, // a
|
||||||
|
const int32_t* int_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b centermod 1: small polynomial product */
|
||||||
|
EXPORT void tnx32_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* torus_res, // output
|
||||||
|
const int32_t* int_a, // a
|
||||||
|
const int32_t* torus_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared gadget decompositions (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// decompose from tnx32
|
||||||
|
|
||||||
|
typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */
|
||||||
|
EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
);
|
||||||
|
EXPORT void delete_tnx32_approxdecomp_gadget(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a // a
|
||||||
|
);
|
||||||
|
|
||||||
|
// decompose from tnx32x2
|
||||||
|
|
||||||
|
typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */
|
||||||
|
EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella,
|
||||||
|
uint64_t kb, uint64_t ellb);
|
||||||
|
EXPORT void delete_tnx32x2_approxdecomp_gadget(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a // a
|
||||||
|
);
|
||||||
|
|
||||||
|
// decompose from tnxdbl
|
||||||
|
|
||||||
|
typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||||
|
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
);
|
||||||
|
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared scalar-vector product (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */
|
||||||
|
typedef struct rnx_svp_ppol_t RNX_SVP_PPOL;
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */
|
||||||
|
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared vector */
|
||||||
|
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res);
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void rnx_svp_apply( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared vector-matrix product (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT;
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||||
|
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res size
|
||||||
|
uint64_t a_size, // a size
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = DFT(a) */
|
||||||
|
EXPORT void vec_rnx_dft(const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = iDFT(a_dft) -- idft is not normalized */
|
||||||
|
EXPORT void vec_rnx_idft(const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
189
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c
Normal file
189
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_avx.c
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
if (nn < 8) {
|
||||||
|
if (nn == 4) {
|
||||||
|
_mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||||
|
} else if (nn == 2) {
|
||||||
|
_mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||||
|
} else if (nn == 1) {
|
||||||
|
*res = *a + *b;
|
||||||
|
} else {
|
||||||
|
NOT_SUPPORTED(); // not a power of 2
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// general case: nn >= 8
|
||||||
|
__m256d x0, x1, x2, x3, x4, x5;
|
||||||
|
const double* aa = a;
|
||||||
|
const double* bb = b;
|
||||||
|
double* rr = res;
|
||||||
|
double* const rrend = res + nn;
|
||||||
|
do {
|
||||||
|
x0 = _mm256_loadu_pd(aa);
|
||||||
|
x1 = _mm256_loadu_pd(aa + 4);
|
||||||
|
x2 = _mm256_loadu_pd(bb);
|
||||||
|
x3 = _mm256_loadu_pd(bb + 4);
|
||||||
|
x4 = _mm256_add_pd(x0, x2);
|
||||||
|
x5 = _mm256_add_pd(x1, x3);
|
||||||
|
_mm256_storeu_pd(rr, x4);
|
||||||
|
_mm256_storeu_pd(rr + 4, x5);
|
||||||
|
aa += 8;
|
||||||
|
bb += 8;
|
||||||
|
rr += 8;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
if (nn < 8) {
|
||||||
|
if (nn == 4) {
|
||||||
|
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||||
|
} else if (nn == 2) {
|
||||||
|
_mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||||
|
} else if (nn == 1) {
|
||||||
|
*res = *a - *b;
|
||||||
|
} else {
|
||||||
|
NOT_SUPPORTED(); // not a power of 2
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// general case: nn >= 8
|
||||||
|
__m256d x0, x1, x2, x3, x4, x5;
|
||||||
|
const double* aa = a;
|
||||||
|
const double* bb = b;
|
||||||
|
double* rr = res;
|
||||||
|
double* const rrend = res + nn;
|
||||||
|
do {
|
||||||
|
x0 = _mm256_loadu_pd(aa);
|
||||||
|
x1 = _mm256_loadu_pd(aa + 4);
|
||||||
|
x2 = _mm256_loadu_pd(bb);
|
||||||
|
x3 = _mm256_loadu_pd(bb + 4);
|
||||||
|
x4 = _mm256_sub_pd(x0, x2);
|
||||||
|
x5 = _mm256_sub_pd(x1, x3);
|
||||||
|
_mm256_storeu_pd(rr, x4);
|
||||||
|
_mm256_storeu_pd(rr + 4, x5);
|
||||||
|
aa += 8;
|
||||||
|
bb += 8;
|
||||||
|
rr += 8;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_negate_avx(uint64_t nn, double* res, const double* b) {
|
||||||
|
if (nn < 8) {
|
||||||
|
if (nn == 4) {
|
||||||
|
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b)));
|
||||||
|
} else if (nn == 2) {
|
||||||
|
_mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b)));
|
||||||
|
} else if (nn == 1) {
|
||||||
|
*res = -*b;
|
||||||
|
} else {
|
||||||
|
NOT_SUPPORTED(); // not a power of 2
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// general case: nn >= 8
|
||||||
|
__m256d x2, x3, x4, x5;
|
||||||
|
const __m256d ZERO = _mm256_set1_pd(0);
|
||||||
|
const double* bb = b;
|
||||||
|
double* rr = res;
|
||||||
|
double* const rrend = res + nn;
|
||||||
|
do {
|
||||||
|
x2 = _mm256_loadu_pd(bb);
|
||||||
|
x3 = _mm256_loadu_pd(bb + 4);
|
||||||
|
x4 = _mm256_sub_pd(ZERO, x2);
|
||||||
|
x5 = _mm256_sub_pd(ZERO, x3);
|
||||||
|
_mm256_storeu_pd(rr, x4);
|
||||||
|
_mm256_storeu_pd(rr + 4, x5);
|
||||||
|
bb += 8;
|
||||||
|
rr += 8;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
88
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h
Normal file
88
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||||
|
#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic.h"
|
||||||
|
|
||||||
|
typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F;
|
||||||
|
typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F;
|
||||||
|
typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F;
|
||||||
|
typedef typeof(vec_rnx_add) VEC_RNX_ADD_F;
|
||||||
|
typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F;
|
||||||
|
typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F;
|
||||||
|
typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F;
|
||||||
|
typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F;
|
||||||
|
typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F;
|
||||||
|
typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F;
|
||||||
|
typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F;
|
||||||
|
typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F;
|
||||||
|
typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F;
|
||||||
|
typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F;
|
||||||
|
typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F;
|
||||||
|
// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F;
|
||||||
|
typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F;
|
||||||
|
typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||||
|
typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F;
|
||||||
|
typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||||
|
typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F;
|
||||||
|
typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||||
|
typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F;
|
||||||
|
typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||||
|
typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F;
|
||||||
|
typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F;
|
||||||
|
typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F;
|
||||||
|
typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F;
|
||||||
|
typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F;
|
||||||
|
typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F;
|
||||||
|
typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F;
|
||||||
|
typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F;
|
||||||
|
typedef typeof(rnx_vmp_prepare_contiguous_tmp_bytes) RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F;
|
||||||
|
typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F;
|
||||||
|
typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F;
|
||||||
|
typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F;
|
||||||
|
typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||||
|
typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F;
|
||||||
|
typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F;
|
||||||
|
|
||||||
|
typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE;
|
||||||
|
struct rnx_module_vtable_t {
|
||||||
|
VEC_RNX_ZERO_F* vec_rnx_zero;
|
||||||
|
VEC_RNX_COPY_F* vec_rnx_copy;
|
||||||
|
VEC_RNX_NEGATE_F* vec_rnx_negate;
|
||||||
|
VEC_RNX_ADD_F* vec_rnx_add;
|
||||||
|
VEC_RNX_SUB_F* vec_rnx_sub;
|
||||||
|
VEC_RNX_ROTATE_F* vec_rnx_rotate;
|
||||||
|
VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one;
|
||||||
|
VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism;
|
||||||
|
VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32;
|
||||||
|
VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32;
|
||||||
|
VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32;
|
||||||
|
VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32;
|
||||||
|
VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2;
|
||||||
|
VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2;
|
||||||
|
VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl;
|
||||||
|
RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product;
|
||||||
|
RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes;
|
||||||
|
TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product;
|
||||||
|
TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes;
|
||||||
|
ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product;
|
||||||
|
ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes;
|
||||||
|
TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product;
|
||||||
|
TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes;
|
||||||
|
RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32;
|
||||||
|
RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2;
|
||||||
|
RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl;
|
||||||
|
BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol;
|
||||||
|
RNX_SVP_PREPARE_F* rnx_svp_prepare;
|
||||||
|
RNX_SVP_APPLY_F* rnx_svp_apply;
|
||||||
|
BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat;
|
||||||
|
RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous;
|
||||||
|
RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* rnx_vmp_prepare_contiguous_tmp_bytes;
|
||||||
|
RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a;
|
||||||
|
RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes;
|
||||||
|
RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft;
|
||||||
|
RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes;
|
||||||
|
VEC_RNX_DFT_F* vec_rnx_dft;
|
||||||
|
VEC_RNX_IDFT_F* vec_rnx_idft;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||||
284
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h
Normal file
284
spqlios/lib/spqlios/arithmetic/vec_rnx_arithmetic_private.h
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "vec_rnx_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_plugin.h"
|
||||||
|
|
||||||
|
typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP;
|
||||||
|
struct fft64_rnx_module_precomp_t {
|
||||||
|
REIM_FFT_PRECOMP* p_fft;
|
||||||
|
REIM_IFFT_PRECOMP* p_ifft;
|
||||||
|
REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul;
|
||||||
|
REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP;
|
||||||
|
union rnx_module_precomp_t {
|
||||||
|
FFT64_RNX_MODULE_PRECOMP fft64;
|
||||||
|
};
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_precomp(MOD_RNX* module);
|
||||||
|
|
||||||
|
void fft64_finalize_rnx_module_precomp(MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
struct rnx_module_info_t {
|
||||||
|
uint64_t n;
|
||||||
|
uint64_t m;
|
||||||
|
RNX_MODULE_TYPE mtype;
|
||||||
|
RNX_MODULE_VTABLE vtable;
|
||||||
|
RNX_MODULE_PRECOMP precomp;
|
||||||
|
void* custom;
|
||||||
|
void (*custom_deleter)(void*);
|
||||||
|
};
|
||||||
|
|
||||||
|
void init_rnx_module_info(MOD_RNX* module, //
|
||||||
|
uint64_t, RNX_MODULE_TYPE mtype);
|
||||||
|
|
||||||
|
void finalize_rnx_module_info(MOD_RNX* module);
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_vtable(MOD_RNX* module);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared gadget decompositions (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct tnx32_approxdec_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
int32_t add_cst; // 1/2.(sum 2^-(i+1)K)
|
||||||
|
int32_t rshift_base; // 32 - K
|
||||||
|
int64_t and_mask; // 2^K-1
|
||||||
|
int64_t or_mask; // double(2^52)
|
||||||
|
double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
uint8_t rshifts[8]; // 32 - (i+1).K
|
||||||
|
};
|
||||||
|
|
||||||
|
struct tnx32x2_approxdec_gadget_t {
|
||||||
|
// TODO
|
||||||
|
};
|
||||||
|
|
||||||
|
struct tnxdbl_approxdecomp_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||||
|
uint64_t and_mask; // uint64(2^(K)-1)
|
||||||
|
uint64_t or_mask; // double(2^52)
|
||||||
|
double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_add_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_rnx_add_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module);
|
||||||
|
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/// gadget decompositions
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
91
spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c
Normal file
91
spqlios/lib/spqlios/arithmetic/vec_rnx_conversions_ref.c
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dbl_to_tndbl_ref( //
|
||||||
|
const void* UNUSED, // N
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double OFF_CST = INT64_C(3) << 51;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
double ai = a[i] + OFF_CST;
|
||||||
|
res[i] = a[i] - (ai - OFF_CST);
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
47
spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c
Normal file
47
spqlios/lib/spqlios/arithmetic/vec_rnx_svp_ref.c
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }
|
||||||
|
|
||||||
|
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }
|
||||||
|
|
||||||
|
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
) {
|
||||||
|
double* const dppol = (double*)ppol;
|
||||||
|
rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
double* const dppol = (double*)ppol;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
double* const res_ptr = res + i * res_sl;
|
||||||
|
// copy the polynomial to res, apply fft in place, call fftvec
|
||||||
|
// _mul, apply ifft in place.
|
||||||
|
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr);
|
||||||
|
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol);
|
||||||
|
reim_ifft(module->precomp.fft64.p_ifft, res_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
196
spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c
Normal file
196
spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_avx.c
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <immintrin.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (row_max > 0 && col_max > 0) {
|
||||||
|
if (nn >= 8) {
|
||||||
|
// let's do some prefetching of the GSW key, since on some cpus,
|
||||||
|
// it helps
|
||||||
|
const uint64_t ms4 = m >> 2; // m/4
|
||||||
|
const uint64_t gsw_iter_doubles = 8 * nrows * ncols;
|
||||||
|
const uint64_t pref_doubles = 1200;
|
||||||
|
const double* gsw_pref_ptr = mat_input;
|
||||||
|
const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles;
|
||||||
|
const double* gsw_pref_ptr_target = mat_input + pref_doubles;
|
||||||
|
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||||
|
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
const double* mat_blk_start;
|
||||||
|
uint64_t blk_i;
|
||||||
|
for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) {
|
||||||
|
// prefetch the next iteration
|
||||||
|
if (gsw_pref_ptr_target < gsw_ptr_end) {
|
||||||
|
gsw_pref_ptr_target += gsw_iter_doubles;
|
||||||
|
if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end;
|
||||||
|
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||||
|
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const double* in;
|
||||||
|
uint64_t in_sl;
|
||||||
|
if (res == a_dft) {
|
||||||
|
// it is in place: copy the input vector
|
||||||
|
in = (double*)tmp_space;
|
||||||
|
in_sl = nn;
|
||||||
|
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||||
|
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||||
|
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// it is out of place: do the product directly
|
||||||
|
in = a_dft;
|
||||||
|
in_sl = a_sl;
|
||||||
|
}
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
{
|
||||||
|
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in, //
|
||||||
|
pmat_col);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in + row_i * in_sl, //
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero out remaining bytes (if any)
|
||||||
|
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
// fft is done in place on the input (tmpa is destroyed)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i) {
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||||
|
}
|
||||||
|
fft64_rnx_vmp_apply_dft_to_dft_avx(module, //
|
||||||
|
res, cols, res_sl, //
|
||||||
|
tmpa, rows, a_sl, //
|
||||||
|
pmat, nrows, ncols, //
|
||||||
|
tmp_space);
|
||||||
|
// ifft is done in place on the output
|
||||||
|
for (uint64_t i = 0; i < cols; ++i) {
|
||||||
|
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||||
|
}
|
||||||
|
// zero out the remaining positions
|
||||||
|
for (uint64_t i = cols; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
251
spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c
Normal file
251
spqlios/lib/spqlios/arithmetic/vec_rnx_vmp_ref.c
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return nrows * ncols * module->n * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
return nn * sizeof(int64_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (row_max > 0 && col_max > 0) {
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const double* in;
|
||||||
|
uint64_t in_sl;
|
||||||
|
if (res == a_dft) {
|
||||||
|
// it is in place: copy the input vector
|
||||||
|
in = (double*)tmp_space;
|
||||||
|
in_sl = nn;
|
||||||
|
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||||
|
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||||
|
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// it is out of place: do the product directly
|
||||||
|
in = a_dft;
|
||||||
|
in_sl = a_sl;
|
||||||
|
}
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
{
|
||||||
|
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in, //
|
||||||
|
pmat_col);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in + row_i * in_sl, //
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero out remaining bytes (if any)
|
||||||
|
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
// fft is done in place on the input (tmpa is destroyed)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i) {
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||||
|
}
|
||||||
|
fft64_rnx_vmp_apply_dft_to_dft_ref(module, //
|
||||||
|
res, cols, res_sl, //
|
||||||
|
tmpa, rows, a_sl, //
|
||||||
|
pmat, nrows, ncols, //
|
||||||
|
tmp_space);
|
||||||
|
// ifft is done in place on the output
|
||||||
|
for (uint64_t i = 0; i < cols; ++i) {
|
||||||
|
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||||
|
}
|
||||||
|
// zero out the remaining positions
|
||||||
|
for (uint64_t i = cols; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
return (128) + (64 * row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
// avx aliases that need to be defined in the same .c file
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module)
|
||||||
|
__attribute((alias("fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
// wrappers
|
||||||
333
spqlios/lib/spqlios/arithmetic/vec_znx.c
Normal file
333
spqlios/lib/spqlios/arithmetic/vec_znx.c
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../q120/q120_arithmetic.h"
|
||||||
|
#include "../q120/q120_ntt.h"
|
||||||
|
#include "../reim/reim_fft_internal.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
// general function (virtual dispatch)
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_add(module, // N
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
b, b_size, b_sl // b
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_sub(module, // N
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
b, b_size, b_sl // b
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_rotate(module, // N
|
||||||
|
p, // p
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl // a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_automorphism(module, // N
|
||||||
|
p, // p
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl // a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space of size >= N
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_normalize_base2k(module, // N
|
||||||
|
log2_base2k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_normalize_base2k_tmp_bytes(module // N
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// specialized function (ref)
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then negate to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
int64_t* res_ptr = res + i * res_sl;
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
znx_rotate_inplace_i64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
znx_rotate_i64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
int64_t* res_ptr = res + i * res_sl;
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
znx_automorphism_inplace_i64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
znx_automorphism_i64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space of size >= N
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
// use MSB limb of res for carry propagation
|
||||||
|
int64_t* cout = (int64_t*)tmp_space;
|
||||||
|
int64_t* cin = 0x0;
|
||||||
|
|
||||||
|
// propagate carry until first limb of res
|
||||||
|
int64_t i = a_size - 1;
|
||||||
|
for (; i >= res_size; --i) {
|
||||||
|
znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin);
|
||||||
|
cin = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
// propagate carry and normalize
|
||||||
|
for (; i >= 1; --i) {
|
||||||
|
znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin);
|
||||||
|
cin = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize last limb
|
||||||
|
znx_normalize(nn, log2_base2k, res, 0x0, a, cin);
|
||||||
|
|
||||||
|
// extend result with zeros
|
||||||
|
for (uint64_t i = a_size; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
return nn * sizeof(int64_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// alias have to be defined in this unit: do not move
|
||||||
|
#ifdef __APPLE__
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
) {
|
||||||
|
return vec_znx_normalize_base2k_tmp_bytes_ref(module);
|
||||||
|
}
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
) {
|
||||||
|
return vec_znx_normalize_base2k_tmp_bytes_ref(module);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_zero(module, res, res_size, res_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
for (uint64_t i = 0; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < smin; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = smin; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < smin; ++i) {
|
||||||
|
znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = smin; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
357
spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h
Normal file
357
spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic.h
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We support the following module families:
|
||||||
|
* - FFT64:
|
||||||
|
* all the polynomials should fit at all times over 52 bits.
|
||||||
|
* for FHE implementations, the recommended limb-sizes are
|
||||||
|
* between K=10 and 20, which is good for low multiplicative depths.
|
||||||
|
* - NTT120:
|
||||||
|
* all the polynomials should fit at all times over 119 bits.
|
||||||
|
* for FHE implementations, the recommended limb-sizes are
|
||||||
|
* between K=20 and 40, which is good for large multiplicative depths.
|
||||||
|
*/
|
||||||
|
typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describr the modules (ZnX,TnX) and the hardware */
|
||||||
|
typedef struct module_info_t MODULE;
|
||||||
|
/** @brief opaque type that represents a prepared matrix */
|
||||||
|
typedef struct vmp_pmat_t VMP_PMAT;
|
||||||
|
/** @brief opaque type that represents a vector of znx in DFT space */
|
||||||
|
typedef struct vec_znx_dft_t VEC_ZNX_DFT;
|
||||||
|
/** @brief opaque type that represents a vector of znx in large coeffs space */
|
||||||
|
typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG;
|
||||||
|
/** @brief opaque type that represents a prepared scalar vector product */
|
||||||
|
typedef struct svp_ppol_t SVP_PPOL;
|
||||||
|
/** @brief opaque type that represents a prepared left convolution vector product */
|
||||||
|
typedef struct cnv_pvec_l_t CNV_PVEC_L;
|
||||||
|
/** @brief opaque type that represents a prepared right convolution vector product */
|
||||||
|
typedef struct cnv_pvec_r_t CNV_PVEC_R;
|
||||||
|
|
||||||
|
/** @brief bytes needed for a vec_znx in DFT space */
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief allocates a vec_znx in DFT space */
|
||||||
|
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief frees memory from a vec_znx in DFT space */
|
||||||
|
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a vec_znx_big */
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief allocates a vec_znx_big */
|
||||||
|
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
/** @brief frees memory from a vec_znx_big */
|
||||||
|
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a prepared vector */
|
||||||
|
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N
|
||||||
|
|
||||||
|
/** @brief allocates a prepared vector */
|
||||||
|
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared vector */
|
||||||
|
EXPORT void delete_svp_ppol(SVP_PPOL* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a prepared matrix */
|
||||||
|
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix */
|
||||||
|
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared matrix */
|
||||||
|
EXPORT void delete_vmp_pmat(VMP_PMAT* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief obtain a module info for ring dimension N
|
||||||
|
* the module-info knows about:
|
||||||
|
* - the dimension N (or the complex dimension m=N/2)
|
||||||
|
* - any moduleuted fft or ntt items
|
||||||
|
* - the hardware (avx, arm64, x86, ...)
|
||||||
|
*/
|
||||||
|
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode);
|
||||||
|
EXPORT void delete_module_info(MODULE* module_info);
|
||||||
|
EXPORT uint64_t module_get_n(const MODULE* module);
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize-reduce(a) */
|
||||||
|
EXPORT void vec_znx_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space (size >= N)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||||
|
const int64_t p, // X-X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row*ncols+col] points to the item) */
|
||||||
|
EXPORT void vmp_prepare_dblptr(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t** mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_dft_zero(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = DFT(a) */
|
||||||
|
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */
|
||||||
|
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for vec_znx_idft */
|
||||||
|
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief sets res = iDFT(a_dft) -- output in big coeffs space
|
||||||
|
*
|
||||||
|
* @note a_dft is overwritten
|
||||||
|
*/
|
||||||
|
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void fft64_svp_apply_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void svp_prepare(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp);
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||||
|
EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
;
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
481
spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h
Normal file
481
spqlios/lib/spqlios/arithmetic/vec_znx_arithmetic_private.h
Normal file
@@ -0,0 +1,481 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "../q120/q120_ntt.h"
|
||||||
|
#include "vec_znx_arithmetic.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Layouts families:
|
||||||
|
*
|
||||||
|
* fft64:
|
||||||
|
* K: <= 20, N: <= 65536, ell: <= 200
|
||||||
|
* vec<ZnX> normalized: represented by int64
|
||||||
|
* vec<ZnX> large: represented by int64 (expect <=52 bits)
|
||||||
|
* vec<ZnX> DFT: represented by double (reim_fft space)
|
||||||
|
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special reim4_fft space
|
||||||
|
*
|
||||||
|
* ntt120:
|
||||||
|
* K: <= 50, N: <= 65536, ell: <= 80
|
||||||
|
* vec<ZnX> normalized: represented by int64
|
||||||
|
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||||
|
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||||
|
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special ntt120 space
|
||||||
|
*
|
||||||
|
* ntt104:
|
||||||
|
* K: <= 40, N: <= 65536, ell: <= 80
|
||||||
|
* vec<ZnX> normalized: represented by int64
|
||||||
|
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||||
|
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||||
|
* On AVX512 inftastructure, PMAT, LCNV, RCNV use a special ntt104 space
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct fft64_module_info_t {
|
||||||
|
// pre-computation for reim_fft
|
||||||
|
REIM_FFT_PRECOMP* p_fft;
|
||||||
|
// pre-computation for mul_fft
|
||||||
|
REIM_FFTVEC_MUL_PRECOMP* mul_fft;
|
||||||
|
// pre-computation for reim_from_znx6
|
||||||
|
REIM_FROM_ZNX64_PRECOMP* p_conv;
|
||||||
|
// pre-computation for reim_tp_znx6
|
||||||
|
REIM_TO_ZNX64_PRECOMP* p_reim_to_znx;
|
||||||
|
// pre-computation for reim_fft
|
||||||
|
REIM_IFFT_PRECOMP* p_ifft;
|
||||||
|
// pre-computation for reim_fftvec_addmul
|
||||||
|
REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct q120_module_info_t {
|
||||||
|
// pre-computation for q120b to q120b ntt
|
||||||
|
q120_ntt_precomp* p_ntt;
|
||||||
|
// pre-computation for q120b to q120b intt
|
||||||
|
q120_ntt_precomp* p_intt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO add function types here
|
||||||
|
typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F;
|
||||||
|
typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F;
|
||||||
|
typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F;
|
||||||
|
typedef typeof(vec_znx_add) VEC_ZNX_ADD_F;
|
||||||
|
typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F;
|
||||||
|
typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F;
|
||||||
|
typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F;
|
||||||
|
typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F;
|
||||||
|
typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F;
|
||||||
|
typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F;
|
||||||
|
typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F;
|
||||||
|
typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F;
|
||||||
|
typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||||
|
typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F;
|
||||||
|
typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||||
|
typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F;
|
||||||
|
typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||||
|
typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F;
|
||||||
|
typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F;
|
||||||
|
typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F;
|
||||||
|
typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F;
|
||||||
|
typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F;
|
||||||
|
typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F;
|
||||||
|
typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F;
|
||||||
|
typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F;
|
||||||
|
typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F;
|
||||||
|
typedef typeof(svp_prepare) SVP_PREPARE;
|
||||||
|
typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F;
|
||||||
|
typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F;
|
||||||
|
typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||||
|
typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F;
|
||||||
|
typedef typeof(vmp_prepare_contiguous_tmp_bytes) VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F;
|
||||||
|
typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F;
|
||||||
|
typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F;
|
||||||
|
typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F;
|
||||||
|
typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||||
|
typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F;
|
||||||
|
typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F;
|
||||||
|
typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F;
|
||||||
|
typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F;
|
||||||
|
|
||||||
|
struct module_virtual_functions_t {
|
||||||
|
// TODO add functions here
|
||||||
|
VEC_ZNX_ZERO_F* vec_znx_zero;
|
||||||
|
VEC_ZNX_COPY_F* vec_znx_copy;
|
||||||
|
VEC_ZNX_NEGATE_F* vec_znx_negate;
|
||||||
|
VEC_ZNX_ADD_F* vec_znx_add;
|
||||||
|
VEC_ZNX_DFT_F* vec_znx_dft;
|
||||||
|
VEC_ZNX_IDFT_F* vec_znx_idft;
|
||||||
|
VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes;
|
||||||
|
VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a;
|
||||||
|
VEC_ZNX_SUB_F* vec_znx_sub;
|
||||||
|
VEC_ZNX_ROTATE_F* vec_znx_rotate;
|
||||||
|
VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism;
|
||||||
|
VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k;
|
||||||
|
VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes;
|
||||||
|
VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k;
|
||||||
|
VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes;
|
||||||
|
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k;
|
||||||
|
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||||
|
VEC_ZNX_BIG_ADD_F* vec_znx_big_add;
|
||||||
|
VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small;
|
||||||
|
VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2;
|
||||||
|
VEC_ZNX_BIG_SUB_F* vec_znx_big_sub;
|
||||||
|
VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a;
|
||||||
|
VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b;
|
||||||
|
VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2;
|
||||||
|
VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate;
|
||||||
|
VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism;
|
||||||
|
SVP_PREPARE* svp_prepare;
|
||||||
|
SVP_APPLY_DFT_F* svp_apply_dft;
|
||||||
|
ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product;
|
||||||
|
ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes;
|
||||||
|
VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous;
|
||||||
|
VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* vmp_prepare_contiguous_tmp_bytes;
|
||||||
|
VMP_APPLY_DFT_F* vmp_apply_dft;
|
||||||
|
VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes;
|
||||||
|
VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft;
|
||||||
|
VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes;
|
||||||
|
BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft;
|
||||||
|
BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big;
|
||||||
|
BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol;
|
||||||
|
BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat;
|
||||||
|
};
|
||||||
|
|
||||||
|
union backend_module_info_t {
|
||||||
|
struct fft64_module_info_t fft64;
|
||||||
|
struct q120_module_info_t q120;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct module_info_t {
|
||||||
|
// generic parameters
|
||||||
|
MODULE_TYPE module_type;
|
||||||
|
uint64_t nn;
|
||||||
|
uint64_t m;
|
||||||
|
// backend_dependent functions
|
||||||
|
union backend_module_info_t mod;
|
||||||
|
// virtual functions
|
||||||
|
struct module_virtual_functions_t func;
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module); // N
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_copy_ref(const MODULE* precomp, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub_ref(const MODULE* precomp, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // inp
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vmp_prepare_ref(const MODULE* precomp, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vmp_apply_dft_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_dft_zero_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_dft_add_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_dft_sub_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_dft_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_idft_ref(const MODULE* precomp, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size);
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_normalize_ref(const MODULE* precomp, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||||
|
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||||
|
EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_range_begin, // a
|
||||||
|
uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module);
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** */
|
||||||
|
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module);
|
||||||
|
|
||||||
|
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
);
|
||||||
|
|
||||||
|
// big additions/subtractions
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp);
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||||
|
EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief this inner function could be very handy */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief this inner function could be very handy */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||||
103
spqlios/lib/spqlios/arithmetic/vec_znx_avx.c
Normal file
103
spqlios/lib/spqlios/arithmetic/vec_znx_avx.c
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
// specialized function (ref)
|
||||||
|
|
||||||
|
// Note: these functions do not have an avx variant.
|
||||||
|
#define znx_copy_i64_avx znx_copy_i64_ref
|
||||||
|
#define znx_zero_i64_avx znx_zero_i64_ref
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then negate to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < smin; ++i) {
|
||||||
|
znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = smin; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
270
spqlios/lib/spqlios/arithmetic/vec_znx_big.c
Normal file
270
spqlios/lib/spqlios/arithmetic/vec_znx_big.c
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->func.bytes_of_vec_znx_big(module, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// public wrappers
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size);
|
||||||
|
}
|
||||||
|
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// private wrappers
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->nn * size * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return spqlios_alloc(bytes_of_vec_znx_big(module, size));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); }
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
a, a_size, a_sl, //
|
||||||
|
b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, //
|
||||||
|
n, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
a, a_size, a_sl, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, //
|
||||||
|
n, a, a_size, //
|
||||||
|
a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_normalize_base2k(module, // N
|
||||||
|
k, // base-2^k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module // N
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k(module, log2_base2k, res, res_size, res_sl, a, a_range_begin,
|
||||||
|
a_range_xend, a_range_step, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space) {
|
||||||
|
uint64_t a_sl = module->nn;
|
||||||
|
module->func.vec_znx_normalize_base2k(module, // N
|
||||||
|
k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
(int64_t*)a, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a
|
||||||
|
uint8_t* tmp_space) {
|
||||||
|
// convert the range indexes to int64[] slices
|
||||||
|
const int64_t* a_st = ((int64_t*)a) + module->nn * a_begin;
|
||||||
|
const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step;
|
||||||
|
const uint64_t a_sl = module->nn * a_step;
|
||||||
|
// forward the call
|
||||||
|
module->func.vec_znx_normalize_base2k(module, // N
|
||||||
|
k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a_st, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
162
spqlios/lib/spqlios/arithmetic/vec_znx_dft.c
Normal file
162
spqlios/lib/spqlios/arithmetic/vec_znx_dft.c
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../q120/q120_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // scratch space
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module) { return module->func.vec_znx_idft_tmp_bytes(module); }
|
||||||
|
|
||||||
|
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->func.bytes_of_vec_znx_dft(module, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fft64 backend
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->nn * size * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return spqlios_alloc(bytes_of_vec_znx_dft(module, size));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); }
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
double* const dres = (double*)res;
|
||||||
|
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // unused
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
if ((double*)res != (double*)a_dft) {
|
||||||
|
memcpy(res, a_dft, smin * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn);
|
||||||
|
reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
int64_t* const dres = (int64_t*)res;
|
||||||
|
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module) { return 0; }
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
int64_t* const tres = (int64_t*)res;
|
||||||
|
double* const ta = (double*)a_dft;
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
reim_ifft(module->mod.fft64.p_ifft, ta + i * nn);
|
||||||
|
reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ntt120 backend
|
||||||
|
|
||||||
|
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
int64_t* tres = (int64_t*)res;
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl);
|
||||||
|
q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
__int128_t* const tres = (__int128_t*)res;
|
||||||
|
const int64_t* const ta = (int64_t*)a_dft;
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t));
|
||||||
|
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp);
|
||||||
|
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module) { return module->nn * 4 * sizeof(uint64_t); }
|
||||||
|
|
||||||
|
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
__int128_t* const tres = (__int128_t*)res;
|
||||||
|
int64_t* const ta = (int64_t*)a_dft;
|
||||||
|
for (uint64_t i = 0; i < smin; i++) {
|
||||||
|
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4));
|
||||||
|
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill up remaining part with 0's
|
||||||
|
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||||
|
}
|
||||||
1
spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c
Normal file
1
spqlios/lib/spqlios/arithmetic/vec_znx_dft_avx2.c
Normal file
@@ -0,0 +1 @@
|
|||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
240
spqlios/lib/spqlios/arithmetic/vector_matrix_product.c
Normal file
240
spqlios/lib/spqlios/arithmetic/vector_matrix_product.c
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols // dimensions
|
||||||
|
) {
|
||||||
|
return module->func.bytes_of_vmp_pmat(module, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fft64
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols // dimensions
|
||||||
|
) {
|
||||||
|
return module->nn * nrows * ncols * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols // dimensions
|
||||||
|
) {
|
||||||
|
return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); }
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||||
|
EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) {
|
||||||
|
return module->func.vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||||
|
EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
return nn * sizeof(int64_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||||
|
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||||
|
|
||||||
|
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||||
|
fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief this inner function could be very handy */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
double* vec_input = (double*)a_dft;
|
||||||
|
double* vec_output = (double*)res;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zero out remaining bytes
|
||||||
|
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
return (row_max * nn * sizeof(double)) + (128) + (64 * row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
return (128) + (64 * row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->func.vmp_apply_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
137
spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c
Normal file
137
spqlios/lib/spqlios/arithmetic/vector_matrix_product_avx.c
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||||
|
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||||
|
|
||||||
|
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||||
|
fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief this inner function could be very handy */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
double* vec_input = (double*)a_dft;
|
||||||
|
double* vec_output = (double*)res;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max)
|
||||||
|
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zero out remaining bytes
|
||||||
|
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
|
||||||
169
spqlios/lib/spqlios/arithmetic/zn_api.c
Normal file
169
spqlios/lib/spqlios/arithmetic/zn_api.c
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
void default_init_z_module_precomp(MOD_Z* module) {
|
||||||
|
// Add here initialization of items that are in the precomp
|
||||||
|
}
|
||||||
|
|
||||||
|
void default_finalize_z_module_precomp(MOD_Z* module) {
|
||||||
|
// Add here deleters for items that are in the precomp
|
||||||
|
}
|
||||||
|
|
||||||
|
void default_init_z_module_vtable(MOD_Z* module) {
|
||||||
|
// Add function pointers here
|
||||||
|
module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref;
|
||||||
|
module->vtable.dbl_to_tn32 = dbl_to_tn32_ref;
|
||||||
|
module->vtable.tn32_to_dbl = tn32_to_dbl_ref;
|
||||||
|
module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref;
|
||||||
|
module->vtable.i32_to_dbl = i32_to_dbl_ref;
|
||||||
|
module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref;
|
||||||
|
module->vtable.i64_to_dbl = i64_to_dbl_ref;
|
||||||
|
|
||||||
|
// Add optimized function pointers here
|
||||||
|
if (CPU_SUPPORTS("avx")) {
|
||||||
|
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx;
|
||||||
|
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx;
|
||||||
|
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_z_module_info(MOD_Z* module, //
|
||||||
|
Z_MODULE_TYPE mtype) {
|
||||||
|
memset(module, 0, sizeof(MOD_Z));
|
||||||
|
module->mtype = mtype;
|
||||||
|
switch (mtype) {
|
||||||
|
case DEFAULT:
|
||||||
|
default_init_z_module_precomp(module);
|
||||||
|
default_init_z_module_vtable(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize_z_module_info(MOD_Z* module) {
|
||||||
|
if (module->custom) module->custom_deleter(module->custom);
|
||||||
|
switch (module->mtype) {
|
||||||
|
case DEFAULT:
|
||||||
|
default_finalize_z_module_precomp(module);
|
||||||
|
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) {
|
||||||
|
MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z));
|
||||||
|
init_z_module_info(res, mtype);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_z_module_info(MOD_Z* module_info) {
|
||||||
|
finalize_z_module_info(module_info);
|
||||||
|
free(module_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////// wrappers //////////////////
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void zn32_vmp_prepare_contiguous( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a
|
||||||
|
module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_to_tn32(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.tn32_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.i32_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int64 space, <= 2^50) to double */
|
||||||
|
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.i64_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
81
spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c
Normal file
81
spqlios/lib/spqlios/arithmetic/zn_approxdecomp_ref.c
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
uint64_t k, uint64_t ell) {
|
||||||
|
if (k * ell > 50) {
|
||||||
|
return spqlios_error("approx decomposition requested is too precise for doubles");
|
||||||
|
}
|
||||||
|
if (k < 1) {
|
||||||
|
return spqlios_error("approx decomposition supports k>=1");
|
||||||
|
}
|
||||||
|
TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||||
|
memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||||
|
res->k = k;
|
||||||
|
res->ell = ell;
|
||||||
|
double add_cst = INT64_C(3) << (51 - k * ell);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
add_cst += pow(2., -(double)(i * k + 1));
|
||||||
|
}
|
||||||
|
res->add_cst = add_cst;
|
||||||
|
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||||
|
res->sub_cst = UINT64_C(1) << (k - 1);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); }
|
||||||
|
|
||||||
|
EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
TNDBL_APPROXDECOMP_GADGET* res, //
|
||||||
|
uint64_t k, uint64_t ell) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
double dv;
|
||||||
|
uint64_t uv;
|
||||||
|
} du_t;
|
||||||
|
|
||||||
|
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \
|
||||||
|
if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \
|
||||||
|
const uint64_t ell = gadget->ell; \
|
||||||
|
const double add_cst = gadget->add_cst; \
|
||||||
|
const uint8_t* const rshifts = gadget->rshifts; \
|
||||||
|
const ITYPE and_mask = gadget->and_mask; \
|
||||||
|
const ITYPE sub_cst = gadget->sub_cst; \
|
||||||
|
ITYPE* rr = res; \
|
||||||
|
const double* aa = a; \
|
||||||
|
const double* aaend = a + a_size; \
|
||||||
|
while (aa < aaend) { \
|
||||||
|
du_t t = {.dv = *aa + add_cst}; \
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) { \
|
||||||
|
ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \
|
||||||
|
*rr = (v & and_mask) - sub_cst; \
|
||||||
|
++rr; \
|
||||||
|
} \
|
||||||
|
++aa; \
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size //
|
||||||
|
){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t)
|
||||||
|
}
|
||||||
135
spqlios/lib/spqlios/arithmetic/zn_arithmetic.h
Normal file
135
spqlios/lib/spqlios/arithmetic/zn_arithmetic.h
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
#ifndef SPQLIOS_ZN_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_ZN_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the module and the hardware */
|
||||||
|
typedef struct z_module_info_t MOD_Z;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief obtain a module info for ring dimension N
|
||||||
|
* the module-info knows about:
|
||||||
|
* - the dimension N (or the complex dimension m=N/2)
|
||||||
|
* - any moduleuted fft or ntt items
|
||||||
|
* - the hardware (avx, arm64, x86, ...)
|
||||||
|
*/
|
||||||
|
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode);
|
||||||
|
EXPORT void delete_z_module_info(MOD_Z* module_info);
|
||||||
|
|
||||||
|
typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
uint64_t k,
|
||||||
|
uint64_t ell); // base 2^k, and size
|
||||||
|
|
||||||
|
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief opaque type that represents a prepared matrix */
|
||||||
|
typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT;
|
||||||
|
|
||||||
|
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||||
|
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief deletes a prepared matrix (release with free) */
|
||||||
|
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void zn32_vmp_prepare_contiguous( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols); // a
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i32( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i16( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i8( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
// explicit conversions
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space.
|
||||||
|
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int32 space) to double
|
||||||
|
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space
|
||||||
|
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int64 space, <= 2^50) to double
|
||||||
|
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_ZN_ARITHMETIC_H
|
||||||
39
spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h
Normal file
39
spqlios/lib/spqlios/arithmetic/zn_arithmetic_plugin.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||||
|
#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||||
|
|
||||||
|
#include "zn_arithmetic.h"
|
||||||
|
|
||||||
|
typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F;
|
||||||
|
typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F;
|
||||||
|
typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F;
|
||||||
|
typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F;
|
||||||
|
typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F;
|
||||||
|
typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F;
|
||||||
|
typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F;
|
||||||
|
typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F;
|
||||||
|
typedef typeof(dbl_to_tn32) DBL_TO_TN32_F;
|
||||||
|
typedef typeof(tn32_to_dbl) TN32_TO_DBL_F;
|
||||||
|
typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F;
|
||||||
|
typedef typeof(i32_to_dbl) I32_TO_DBL_F;
|
||||||
|
typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F;
|
||||||
|
typedef typeof(i64_to_dbl) I64_TO_DBL_F;
|
||||||
|
|
||||||
|
typedef struct z_module_vtable_t Z_MODULE_VTABLE;
|
||||||
|
struct z_module_vtable_t {
|
||||||
|
I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl;
|
||||||
|
I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl;
|
||||||
|
I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl;
|
||||||
|
BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat;
|
||||||
|
ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous;
|
||||||
|
ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32;
|
||||||
|
ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16;
|
||||||
|
ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8;
|
||||||
|
DBL_TO_TN32_F* dbl_to_tn32;
|
||||||
|
TN32_TO_DBL_F* tn32_to_dbl;
|
||||||
|
DBL_ROUND_TO_I32_F* dbl_round_to_i32;
|
||||||
|
I32_TO_DBL_F* i32_to_dbl;
|
||||||
|
DBL_ROUND_TO_I64_F* dbl_round_to_i64;
|
||||||
|
I64_TO_DBL_F* i64_to_dbl;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||||
150
spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h
Normal file
150
spqlios/lib/spqlios/arithmetic/zn_arithmetic_private.h
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "zn_arithmetic.h"
|
||||||
|
#include "zn_arithmetic_plugin.h"
|
||||||
|
|
||||||
|
typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP;
|
||||||
|
struct main_z_module_precomp_t {
|
||||||
|
// TODO
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef union z_module_precomp_t Z_MODULE_PRECOMP;
|
||||||
|
union z_module_precomp_t {
|
||||||
|
MAIN_Z_MODULE_PRECOMP main;
|
||||||
|
};
|
||||||
|
|
||||||
|
void main_init_z_module_precomp(MOD_Z* module);
|
||||||
|
|
||||||
|
void main_finalize_z_module_precomp(MOD_Z* module);
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
struct z_module_info_t {
|
||||||
|
Z_MODULE_TYPE mtype;
|
||||||
|
Z_MODULE_VTABLE vtable;
|
||||||
|
Z_MODULE_PRECOMP precomp;
|
||||||
|
void* custom;
|
||||||
|
void (*custom_deleter)(void*);
|
||||||
|
};
|
||||||
|
|
||||||
|
void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype);
|
||||||
|
|
||||||
|
void main_init_z_module_vtable(MOD_Z* module);
|
||||||
|
|
||||||
|
struct tndbl_approxdecomp_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K)
|
||||||
|
int64_t and_mask; // (2^K)-1
|
||||||
|
int64_t sub_cst; // 2^(K-1)
|
||||||
|
uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1]
|
||||||
|
};
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res,
|
||||||
|
uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res,
|
||||||
|
uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_ref( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i16_ref( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i8_ref( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_avx( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i16_avx( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i8_avx( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
// explicit conversions
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int64 space) to double */
|
||||||
|
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
108
spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c
Normal file
108
spqlios/lib/spqlios/arithmetic/zn_conversions_ref.c
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
double dv;
|
||||||
|
int64_t s64v;
|
||||||
|
int32_t s32v;
|
||||||
|
uint64_t u64v;
|
||||||
|
uint32_t u32v;
|
||||||
|
} di_t;
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32));
|
||||||
|
static const int32_t XOR_CST = (INT32_C(1) << 31);
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = t.s32v ^ XOR_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||||
|
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))};
|
||||||
|
static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32));
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
uint32_t ai = a[i] ^ XOR_CST;
|
||||||
|
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31)));
|
||||||
|
static const int32_t XOR_CST = INT32_C(1) << 31;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = t.s32v ^ XOR_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||||
|
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)};
|
||||||
|
static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31));
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
uint32_t ai = a[i] ^ XOR_CST;
|
||||||
|
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = (double)(INT64_C(3) << (51));
|
||||||
|
static const int64_t AND_CST = (INT64_C(1) << 52) - 1;
|
||||||
|
static const int64_t SUB_CST = INT64_C(1) << 51;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = (t.s64v & AND_CST) - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int64_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int64 space) to double */
|
||||||
|
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint64_t ADD_CST = UINT64_C(1) << 51;
|
||||||
|
static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1;
|
||||||
|
static const di_t OR_CST = {.dv = (INT64_C(1) << 52)};
|
||||||
|
static const double SUB_CST = INT64_C(3) << 51;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c
Normal file
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int16_avx.c
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#define INTTYPE int16_t
|
||||||
|
#define INTSN i16
|
||||||
|
|
||||||
|
#include "zn_vmp_int32_avx.c"
|
||||||
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c
Normal file
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int16_ref.c
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#define INTTYPE int16_t
|
||||||
|
#define INTSN i16
|
||||||
|
|
||||||
|
#include "zn_vmp_int32_ref.c"
|
||||||
223
spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c
Normal file
223
spqlios/lib/spqlios/arithmetic/zn_vmp_int32_avx.c
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
// This file is actually a template: it will be compiled multiple times with
|
||||||
|
// different INTTYPES
|
||||||
|
#ifndef INTTYPE
|
||||||
|
#define INTTYPE int32_t
|
||||||
|
#define INTSN i32
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <immintrin.h>
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||||
|
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||||
|
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||||
|
|
||||||
|
static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 32 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int32_t* bb = b;
|
||||||
|
const int32_t* pref_bb = b;
|
||||||
|
const uint64_t pref_iters = 128;
|
||||||
|
const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows;
|
||||||
|
const uint64_t pref_last = pref_iters > nrows ? 0 : nrows - pref_iters;
|
||||||
|
// let's do some prefetching of the GSW key, since on some cpus,
|
||||||
|
// it helps
|
||||||
|
for (uint64_t i = 0; i < pref_start; ++i) {
|
||||||
|
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||||
|
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||||
|
pref_bb += 32;
|
||||||
|
}
|
||||||
|
// we do the first iteration
|
||||||
|
__m256i x = _mm256_set1_epi32(a[0]);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||||
|
bb += 32;
|
||||||
|
uint64_t row = 1;
|
||||||
|
for (; //
|
||||||
|
row < pref_last; //
|
||||||
|
++row, bb += 32) {
|
||||||
|
// prefetch the next iteration
|
||||||
|
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||||
|
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||||
|
pref_bb += 32;
|
||||||
|
INTTYPE ai = a[row];
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
for (; //
|
||||||
|
row < nrows; //
|
||||||
|
++row, bb += 32) {
|
||||||
|
INTTYPE ai = a[row];
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 32 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 24 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 16 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 8 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*vm_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const INTTYPE* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
static const vm_f zn32_vec_mat8kcols_avx[4] = { //
|
||||||
|
zn32_vec_fn(mat8cols_avx), //
|
||||||
|
zn32_vec_fn(mat16cols_avx), //
|
||||||
|
zn32_vec_fn(mat24cols_avx), //
|
||||||
|
zn32_vec_fn(mat32cols_avx)};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const INTTYPE* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint64_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
|
||||||
88
spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c
Normal file
88
spqlios/lib/spqlios/arithmetic/zn_vmp_int32_ref.c
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// This file is actually a template: it will be compiled multiple times with
|
||||||
|
// different INTTYPES
|
||||||
|
#ifndef INTTYPE
|
||||||
|
#define INTTYPE int32_t
|
||||||
|
#define INTSN i32
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||||
|
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||||
|
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||||
|
|
||||||
|
// the ref version shares the same implementation for each fixed column size
|
||||||
|
// optimized implementations may do something different.
|
||||||
|
static __always_inline void IMPL_zn32_vec_matcols_ref(
|
||||||
|
const uint64_t NCOLS, // fixed number of columns
|
||||||
|
uint64_t nrows, // nrows of b
|
||||||
|
int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant
|
||||||
|
const INTTYPE* a, // a: nrows-sized vector
|
||||||
|
const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix
|
||||||
|
) {
|
||||||
|
memset(res, 0, NCOLS * sizeof(int32_t));
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t ai = a[row];
|
||||||
|
const int32_t* bb = b + row * b_sl;
|
||||||
|
for (uint64_t i = 0; i < NCOLS; ++i) {
|
||||||
|
res[i] += ai * bb[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*vm_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const INTTYPE* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
static const vm_f zn32_vec_mat8kcols_ref[4] = { //
|
||||||
|
zn32_vec_fn(mat8cols_ref), //
|
||||||
|
zn32_vec_fn(mat16cols_ref), //
|
||||||
|
zn32_vec_fn(mat24cols_ref), //
|
||||||
|
zn32_vec_fn(mat32cols_ref)};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const INTTYPE* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint32_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
|
||||||
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c
Normal file
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int8_avx.c
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#define INTTYPE int8_t
|
||||||
|
#define INTSN i8
|
||||||
|
|
||||||
|
#include "zn_vmp_int32_avx.c"
|
||||||
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c
Normal file
4
spqlios/lib/spqlios/arithmetic/zn_vmp_int8_ref.c
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#define INTTYPE int8_t
|
||||||
|
#define INTSN i8
|
||||||
|
|
||||||
|
#include "zn_vmp_int32_ref.c"
|
||||||
138
spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c
Normal file
138
spqlios/lib/spqlios/arithmetic/zn_vmp_ref.c
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols // dimensions
|
||||||
|
) {
|
||||||
|
return (nrows * ncols + 7) * sizeof(int32_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||||
|
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) {
|
||||||
|
return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief deletes a prepared matrix (release with free) */
|
||||||
|
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
) {
|
||||||
|
int32_t* const out = (int32_t*)pmat;
|
||||||
|
const uint64_t nblk = ncols >> 5;
|
||||||
|
const uint64_t ncols_rem = ncols & 31;
|
||||||
|
const uint64_t final_elems = (8 - nrows * ncols) & 7;
|
||||||
|
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||||
|
int32_t* outblk = out + blk * nrows * 32;
|
||||||
|
const int32_t* srcblk = mat + blk * 32;
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t* dest = outblk + row * 32;
|
||||||
|
const int32_t* src = srcblk + row * ncols;
|
||||||
|
for (uint64_t i = 0; i < 32; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// copy the last block if any
|
||||||
|
if (ncols_rem) {
|
||||||
|
int32_t* outblk = out + nblk * nrows * 32;
|
||||||
|
const int32_t* srcblk = mat + nblk * 32;
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t* dest = outblk + row * ncols_rem;
|
||||||
|
const int32_t* src = srcblk + row * ncols;
|
||||||
|
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the final elements that may be accessed
|
||||||
|
if (final_elems) {
|
||||||
|
int32_t* f = out + nrows * ncols;
|
||||||
|
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||||
|
f[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
|
||||||
|
#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \
|
||||||
|
memset(res, 0, NCOLS * sizeof(int32_t)); \
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) { \
|
||||||
|
int32_t ai = a[row]; \
|
||||||
|
const int32_t* bb = b + row * b_sl; \
|
||||||
|
for (uint64_t i = 0; i < NCOLS; ++i) { \
|
||||||
|
res[i] += ai * bb[i]; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32)
|
||||||
|
|
||||||
|
void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat24cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat16cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat8cols_ref()
|
||||||
|
}
|
||||||
|
typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const int32_t* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { //
|
||||||
|
zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, //
|
||||||
|
zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const int32_t* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint32_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
38
spqlios/lib/spqlios/arithmetic/znx_small.c
Normal file
38
spqlios/lib/spqlios/arithmetic/znx_small.c
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
double* const ffta = (double*)tmp;
|
||||||
|
double* const fftb = ((double*)tmp) + nn;
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, ffta, a);
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, fftb, b);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, ffta);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, fftb);
|
||||||
|
reim_fftvec_mul_simple(module->m, ffta, ffta, fftb);
|
||||||
|
reim_ifft(module->mod.fft64.p_ifft, ffta);
|
||||||
|
reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module) {
|
||||||
|
return 2 * module->nn * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp) {
|
||||||
|
module->func.znx_small_single_product(module, res, a, b, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module) {
|
||||||
|
return module->func.znx_small_single_product_tmp_bytes(module);
|
||||||
|
}
|
||||||
496
spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c
Normal file
496
spqlios/lib/spqlios/coeffs/coeffs_arithmetic.c
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
#include "coeffs_arithmetic.h"
|
||||||
|
|
||||||
|
#include <memory.h>
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/** res = a + b */
|
||||||
|
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] + b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/** res = a - b */
|
||||||
|
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = -a[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); }
|
||||||
|
|
||||||
|
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); }
|
||||||
|
|
||||||
|
EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) {
|
||||||
|
const double invm = 1. / m;
|
||||||
|
for (uint64_t i = 0; i < n; ++i) {
|
||||||
|
res[i] = a[i] * invm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0 < p < 2nn
|
||||||
|
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
res[0] = in[0];
|
||||||
|
uint64_t a = 0;
|
||||||
|
uint64_t _2mn = 2 * nn - 1;
|
||||||
|
for (uint64_t i = 1; i < nn; i++) {
|
||||||
|
a = (a + p) & _2mn; // i*p mod 2n
|
||||||
|
if (a < nn) {
|
||||||
|
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||||
|
} else {
|
||||||
|
res[a - nn] = -in[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
res[0] = in[0];
|
||||||
|
uint64_t a = 0;
|
||||||
|
uint64_t _2mn = 2 * nn - 1;
|
||||||
|
for (uint64_t i = 1; i < nn; i++) {
|
||||||
|
a = (a + p) & _2mn;
|
||||||
|
if (a < nn) {
|
||||||
|
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||||
|
} else {
|
||||||
|
res[a - nn] = -in[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
int64_t tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
int64_t tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) {
|
||||||
|
return (x << (64 - base_k)) >> (64 - base_k);
|
||||||
|
}
|
||||||
|
|
||||||
|
__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) {
|
||||||
|
return (x - digit) >> base_k;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||||
|
const int64_t* carry_in) {
|
||||||
|
assert(in);
|
||||||
|
if (out != 0) {
|
||||||
|
if (carry_in != 0x0 && carry_out != 0x0) {
|
||||||
|
// with carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
} else if (carry_in != 0) {
|
||||||
|
// with carry in and carry out is dropped
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (carry_out != 0) {
|
||||||
|
// no carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
|
||||||
|
int64_t y = get_base_k_digit(x, base_k);
|
||||||
|
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// no carry in and carry out is dropped
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
out[i] = get_base_k_digit(in[i], base_k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(carry_out);
|
||||||
|
if (carry_in != 0x0) {
|
||||||
|
// with carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||||
|
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// no carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
|
||||||
|
int64_t y = get_base_k_digit(x, base_k);
|
||||||
|
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||||
|
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
const uint64_t m = nn >> 1;
|
||||||
|
// reduce p mod 2n
|
||||||
|
p &= _2mn;
|
||||||
|
// uint64_t vp = p & _2mn;
|
||||||
|
/// uint64_t target_modifs = m >> 1;
|
||||||
|
// we proceed by increasing binary valuation
|
||||||
|
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||||
|
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||||
|
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||||
|
// At the beginning of this loop we have:
|
||||||
|
// vp = binval * p mod 2n
|
||||||
|
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||||
|
|
||||||
|
// first, handle the orders 1 and 2.
|
||||||
|
// if p*binval == binval % 2n: we're done!
|
||||||
|
if (vp == binval) return;
|
||||||
|
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||||
|
if (((vp + binval) & _2mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += binval) {
|
||||||
|
int64_t tmp = res[j];
|
||||||
|
res[j] = -res[nn - j];
|
||||||
|
res[nn - j] = -tmp;
|
||||||
|
}
|
||||||
|
res[m] = -res[m];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||||
|
if (((vp - binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||||
|
res[j] = -res[j];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||||
|
if (((vp + binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||||
|
int64_t tmp = res[j];
|
||||||
|
res[j] = res[nn - j];
|
||||||
|
res[nn - j] = tmp;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// otherwise we will follow the orbit cycles,
|
||||||
|
// starting from binval and -binval in parallel
|
||||||
|
uint64_t j_start = binval;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
while (nb_modif < orb_size) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
int64_t tmp1 = res[j];
|
||||||
|
int64_t tmp2 = res[nn - j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
int64_t tmp1a = res[new_j_n];
|
||||||
|
int64_t tmp2a = res[nn - new_j_n];
|
||||||
|
if (new_j < nn) {
|
||||||
|
res[new_j_n] = tmp1;
|
||||||
|
res[nn - new_j_n] = tmp2;
|
||||||
|
} else {
|
||||||
|
res[new_j_n] = -tmp1;
|
||||||
|
res[nn - new_j_n] = -tmp2;
|
||||||
|
}
|
||||||
|
tmp1 = tmp1a;
|
||||||
|
tmp2 = tmp2a;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
nb_modif += 2;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do *5, because 5 is a generator.
|
||||||
|
j_start = (5 * j_start) & _mn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
const uint64_t m = nn >> 1;
|
||||||
|
// reduce p mod 2n
|
||||||
|
p &= _2mn;
|
||||||
|
// uint64_t vp = p & _2mn;
|
||||||
|
/// uint64_t target_modifs = m >> 1;
|
||||||
|
// we proceed by increasing binary valuation
|
||||||
|
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||||
|
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||||
|
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||||
|
// At the beginning of this loop we have:
|
||||||
|
// vp = binval * p mod 2n
|
||||||
|
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||||
|
|
||||||
|
// first, handle the orders 1 and 2.
|
||||||
|
// if p*binval == binval % 2n: we're done!
|
||||||
|
if (vp == binval) return;
|
||||||
|
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||||
|
if (((vp + binval) & _2mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += binval) {
|
||||||
|
double tmp = res[j];
|
||||||
|
res[j] = -res[nn - j];
|
||||||
|
res[nn - j] = -tmp;
|
||||||
|
}
|
||||||
|
res[m] = -res[m];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||||
|
if (((vp - binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||||
|
res[j] = -res[j];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||||
|
if (((vp + binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||||
|
double tmp = res[j];
|
||||||
|
res[j] = res[nn - j];
|
||||||
|
res[nn - j] = tmp;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// otherwise we will follow the orbit cycles,
|
||||||
|
// starting from binval and -binval in parallel
|
||||||
|
uint64_t j_start = binval;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
while (nb_modif < orb_size) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
double tmp2 = res[nn - j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp1a = res[new_j_n];
|
||||||
|
double tmp2a = res[nn - new_j_n];
|
||||||
|
if (new_j < nn) {
|
||||||
|
res[new_j_n] = tmp1;
|
||||||
|
res[nn - new_j_n] = tmp2;
|
||||||
|
} else {
|
||||||
|
res[new_j_n] = -tmp1;
|
||||||
|
res[nn - new_j_n] = -tmp2;
|
||||||
|
}
|
||||||
|
tmp1 = tmp1a;
|
||||||
|
tmp2 = tmp2a;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
nb_modif += 2;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do *5, because 5 is a generator.
|
||||||
|
j_start = (5 * j_start) & _mn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
78
spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h
Normal file
78
spqlios/lib/spqlios/coeffs/coeffs_arithmetic.h
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#ifndef SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
/** res = a + b */
|
||||||
|
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
/** res = a - b */
|
||||||
|
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
/** res = -a */
|
||||||
|
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
/** res = a */
|
||||||
|
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
/** res = 0 */
|
||||||
|
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
|
||||||
|
|
||||||
|
/** res = a / m where m is a power of 2 */
|
||||||
|
EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a);
|
||||||
|
EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param res = X^p *in mod X^nn +1
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p a power for the rotation -2nn <= p <= 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||||
|
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief res(X) = in(X^p)
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p is odd integer and must be between 0 < p < 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||||
|
EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief res = (X^p-1).in
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p must be between -2nn <= p <= 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Normalize input plus carry mod-2^k. The following
|
||||||
|
* equality holds @c {in + carry_in == out + carry_out . 2^k}.
|
||||||
|
*
|
||||||
|
* @c in must be in [-2^62 .. 2^62]
|
||||||
|
*
|
||||||
|
* @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
|
||||||
|
*
|
||||||
|
* @c carry_in and @carry_out have at most 64+1-k bits.
|
||||||
|
*
|
||||||
|
* Null @c carry_in or @c carry_out are ignored.
|
||||||
|
*
|
||||||
|
* @param[in] nn the ring dimension
|
||||||
|
* @param[in] base_k the base k
|
||||||
|
* @param out output normalized znx
|
||||||
|
* @param carry_out output carry znx
|
||||||
|
* @param[in] in input znx
|
||||||
|
* @param[in] carry_in input carry znx
|
||||||
|
*/
|
||||||
|
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||||
|
const int64_t* carry_in);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
124
spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c
Normal file
124
spqlios/lib/spqlios/coeffs/coeffs_arithmetic_avx.c
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "coeffs_arithmetic.h"
|
||||||
|
|
||||||
|
// res = a + b. dimension n must be a power of 2
|
||||||
|
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = a[0] + b[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_add_epi64( //
|
||||||
|
_mm_loadu_si128((__m128i*)a), //
|
||||||
|
_mm_loadu_si128((__m128i*)b)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
const __m256i* bb = (__m256i*)b;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_add_epi64( //
|
||||||
|
_mm256_loadu_si256(aa), //
|
||||||
|
_mm256_loadu_si256(bb)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
++bb;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// res = a - b. dimension n must be a power of 2
|
||||||
|
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = a[0] - b[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_sub_epi64( //
|
||||||
|
_mm_loadu_si128((__m128i*)a), //
|
||||||
|
_mm_loadu_si128((__m128i*)b)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
const __m256i* bb = (__m256i*)b;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_sub_epi64( //
|
||||||
|
_mm256_loadu_si256(aa), //
|
||||||
|
_mm256_loadu_si256(bb)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
++bb;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = -a[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_sub_epi64( //
|
||||||
|
_mm_set1_epi64x(0), //
|
||||||
|
_mm_loadu_si128((__m128i*)a)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_sub_epi64( //
|
||||||
|
_mm256_set1_epi64x(0), //
|
||||||
|
_mm256_loadu_si256(aa)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
|
||||||
|
// TODO: see if there is a faster way of dividing by a power of 2?
|
||||||
|
const double invm = 1. / m;
|
||||||
|
if (n < 8) {
|
||||||
|
switch (n) {
|
||||||
|
case 1:
|
||||||
|
*res = *a * invm;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
_mm_storeu_pd(res, //
|
||||||
|
_mm_mul_pd(_mm_loadu_pd(a), //
|
||||||
|
_mm_set1_pd(invm)));
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
_mm256_storeu_pd(res, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(a), //
|
||||||
|
_mm256_set1_pd(invm)));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // non-power of 2
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const __m256d invm256 = _mm256_set1_pd(invm);
|
||||||
|
double* rr = res;
|
||||||
|
const double* aa = a;
|
||||||
|
const double* const aaend = a + n;
|
||||||
|
do {
|
||||||
|
_mm256_storeu_pd(rr, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(aa), //
|
||||||
|
invm256));
|
||||||
|
_mm256_storeu_pd(rr + 4, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(aa + 4), //
|
||||||
|
invm256));
|
||||||
|
rr += 8;
|
||||||
|
aa += 8;
|
||||||
|
} while (aa < aaend);
|
||||||
|
}
|
||||||
165
spqlios/lib/spqlios/commons.c
Normal file
165
spqlios/lib/spqlios/commons.c
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
#include "commons.h"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); }
|
||||||
|
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); }
|
||||||
|
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||||
|
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||||
|
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); }
|
||||||
|
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); }
|
||||||
|
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); }
|
||||||
|
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); }
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); }
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define __always_inline inline __attribute((always_inline))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void internal_accurate_sincos(double* rcos, double* rsin, double x) {
|
||||||
|
double _4_x_over_pi = 4 * x / M_PI;
|
||||||
|
int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7;
|
||||||
|
double frac_part = _4_x_over_pi - (double)(int_part);
|
||||||
|
double frac_x = M_PI * frac_part / 4.;
|
||||||
|
// compute the taylor series
|
||||||
|
double cosp = 1.;
|
||||||
|
double sinp = 0.;
|
||||||
|
double powx = 1.;
|
||||||
|
int64_t nn = 0;
|
||||||
|
while (fabs(powx) > 1e-20) {
|
||||||
|
++nn;
|
||||||
|
powx = powx * frac_x / (double)(nn); // x^n/n!
|
||||||
|
switch (nn & 3) {
|
||||||
|
case 0:
|
||||||
|
cosp += powx;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
sinp += powx;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
cosp -= powx;
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
sinp -= powx;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
abort(); // impossible
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// final multiplication
|
||||||
|
switch (int_part) {
|
||||||
|
case 0:
|
||||||
|
*rcos = cosp;
|
||||||
|
*rsin = sinp;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
*rcos = M_SQRT1_2 * (cosp - sinp);
|
||||||
|
*rsin = M_SQRT1_2 * (cosp + sinp);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
*rcos = -sinp;
|
||||||
|
*rsin = cosp;
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
*rcos = -M_SQRT1_2 * (cosp + sinp);
|
||||||
|
*rsin = M_SQRT1_2 * (cosp - sinp);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
*rcos = -cosp;
|
||||||
|
*rsin = -sinp;
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
*rcos = -M_SQRT1_2 * (cosp - sinp);
|
||||||
|
*rsin = -M_SQRT1_2 * (cosp + sinp);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
*rcos = sinp;
|
||||||
|
*rsin = -cosp;
|
||||||
|
break;
|
||||||
|
case 7:
|
||||||
|
*rcos = M_SQRT1_2 * (cosp + sinp);
|
||||||
|
*rsin = -M_SQRT1_2 * (cosp - sinp);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
abort(); // impossible
|
||||||
|
}
|
||||||
|
if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) {
|
||||||
|
printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x));
|
||||||
|
printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x));
|
||||||
|
printf("fracx = %.17lf\n", frac_x);
|
||||||
|
printf("cosp = %.17lf\n", cosp);
|
||||||
|
printf("sinp = %.17lf\n", sinp);
|
||||||
|
printf("nn = %d\n", (int)(nn));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double internal_accurate_cos(double x) {
|
||||||
|
double rcos, rsin;
|
||||||
|
internal_accurate_sincos(&rcos, &rsin, x);
|
||||||
|
return rcos;
|
||||||
|
}
|
||||||
|
double internal_accurate_sin(double x) {
|
||||||
|
double rcos, rsin;
|
||||||
|
internal_accurate_sincos(&rcos, &rsin, x);
|
||||||
|
return rsin;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); }
|
||||||
|
|
||||||
|
EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; }
|
||||||
|
|
||||||
|
EXPORT void spqlios_free(void* addr) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// in debug mode, we deallocated with spqlios_debug_free()
|
||||||
|
spqlios_debug_free(addr);
|
||||||
|
#else
|
||||||
|
// in release mode, the function will free aligned memory
|
||||||
|
#ifdef _WIN32
|
||||||
|
_aligned_free(addr);
|
||||||
|
#else
|
||||||
|
free(addr);
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* spqlios_alloc(uint64_t size) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// in debug mode, the function will not necessarily have any particular alignment
|
||||||
|
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||||
|
return spqlios_debug_alloc(size);
|
||||||
|
#else
|
||||||
|
// in release mode, the function will return 64-bytes aligned memory
|
||||||
|
#ifdef _WIN32
|
||||||
|
void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64);
|
||||||
|
#else
|
||||||
|
void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64)));
|
||||||
|
#endif
|
||||||
|
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||||
|
return reps;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// in debug mode, the function will not necessarily have any particular alignment
|
||||||
|
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||||
|
return spqlios_debug_alloc(size);
|
||||||
|
#else
|
||||||
|
// in release mode, the function will return aligned memory
|
||||||
|
#ifdef _WIN32
|
||||||
|
void* reps = _aligned_malloc(size, align);
|
||||||
|
#else
|
||||||
|
void* reps = aligned_alloc(align, size);
|
||||||
|
#endif
|
||||||
|
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||||
|
return reps;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
77
spqlios/lib/spqlios/commons.h
Normal file
77
spqlios/lib/spqlios/commons.h
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
#ifndef SPQLIOS_COMMONS_H
|
||||||
|
#define SPQLIOS_COMMONS_H
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#define EXPORT extern "C"
|
||||||
|
#define EXPORT_DECL extern "C"
|
||||||
|
#else
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#define EXPORT
|
||||||
|
#define EXPORT_DECL extern
|
||||||
|
#define nullptr 0x0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define UNDEFINED() \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
#define NOT_IMPLEMENTED() \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
#define FATAL_ERROR(MESSAGE) \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m);
|
||||||
|
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m);
|
||||||
|
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n);
|
||||||
|
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n);
|
||||||
|
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n);
|
||||||
|
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a);
|
||||||
|
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a);
|
||||||
|
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n);
|
||||||
|
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n);
|
||||||
|
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n);
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_dp(double* a);
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_vp(void* p);
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c);
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b);
|
||||||
|
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o);
|
||||||
|
|
||||||
|
// windows
|
||||||
|
|
||||||
|
#if defined(_WIN32) || defined(__APPLE__)
|
||||||
|
#define __always_inline inline __attribute((always_inline))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EXPORT void spqlios_free(void* address);
|
||||||
|
|
||||||
|
EXPORT void* spqlios_alloc(uint64_t size);
|
||||||
|
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size);
|
||||||
|
|
||||||
|
#define USE_LIBM_SIN_COS
|
||||||
|
#ifndef USE_LIBM_SIN_COS
|
||||||
|
// if at some point, we want to remove the libm dependency, we can
|
||||||
|
// consider this:
|
||||||
|
EXPORT double internal_accurate_cos(double x);
|
||||||
|
EXPORT double internal_accurate_sin(double x);
|
||||||
|
EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x);
|
||||||
|
#define m_accurate_cos internal_accurate_cos
|
||||||
|
#define m_accurate_sin internal_accurate_sin
|
||||||
|
#else
|
||||||
|
// let's use libm sin and cos
|
||||||
|
#define m_accurate_cos cos
|
||||||
|
#define m_accurate_sin sin
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // SPQLIOS_COMMONS_H
|
||||||
55
spqlios/lib/spqlios/commons_private.c
Normal file
55
spqlios/lib/spqlios/commons_private.c
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
#include "commons_private.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "commons.h"
|
||||||
|
|
||||||
|
EXPORT void* spqlios_error(const char* error) {
|
||||||
|
fputs(error, stderr);
|
||||||
|
abort();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) {
|
||||||
|
if (!ptr2) {
|
||||||
|
free(ptr);
|
||||||
|
}
|
||||||
|
return ptr2;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint32_t log2m(uint32_t m) {
|
||||||
|
uint32_t a = m - 1;
|
||||||
|
if (m & a) FATAL_ERROR("m must be a power of two");
|
||||||
|
a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u);
|
||||||
|
a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u);
|
||||||
|
a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu);
|
||||||
|
a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu);
|
||||||
|
return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; }
|
||||||
|
|
||||||
|
uint32_t revbits(uint32_t nbits, uint32_t value) {
|
||||||
|
uint32_t res = 0;
|
||||||
|
for (uint32_t i = 0; i < nbits; ++i) {
|
||||||
|
res = (res << 1) + (value & 1);
|
||||||
|
value >>= 1;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||||
|
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||||
|
double fracrevbits(uint32_t i) {
|
||||||
|
if (i == 0) return 0;
|
||||||
|
if (i == 1) return 0.5;
|
||||||
|
if (i % 2 == 0)
|
||||||
|
return fracrevbits(i / 2) / 2.;
|
||||||
|
else
|
||||||
|
return fracrevbits((i - 1) / 2) / 2. + 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); }
|
||||||
|
|
||||||
|
uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); }
|
||||||
72
spqlios/lib/spqlios/commons_private.h
Normal file
72
spqlios/lib/spqlios/commons_private.h
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
#ifndef SPQLIOS_COMMONS_PRIVATE_H
|
||||||
|
#define SPQLIOS_COMMONS_PRIVATE_H
|
||||||
|
|
||||||
|
#include "commons.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#else
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#define nullptr 0x0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief log2 of a power of two (UB if m is not a power of two) */
|
||||||
|
EXPORT uint32_t log2m(uint32_t m);
|
||||||
|
|
||||||
|
/** @brief checks if the doublevalue is a power of two */
|
||||||
|
EXPORT uint64_t is_not_pow2_double(void* doublevalue);
|
||||||
|
|
||||||
|
#define UNDEFINED() \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
#define NOT_IMPLEMENTED() \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
#define NOT_SUPPORTED() \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "NOT SUPPORTED!!!\n"); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
#define FATAL_ERROR(MESSAGE) \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||||
|
abort(); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])
|
||||||
|
|
||||||
|
/** @brief reports the error and returns nullptr */
|
||||||
|
EXPORT void* spqlios_error(const char* error);
|
||||||
|
/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */
|
||||||
|
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2);
|
||||||
|
|
||||||
|
#ifdef __x86_64__
|
||||||
|
#define CPU_SUPPORTS __builtin_cpu_supports
|
||||||
|
#else
|
||||||
|
// TODO for now, we do not have any optimization for non x86 targets
|
||||||
|
#define CPU_SUPPORTS(xxxx) 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief returns the n bits of value in reversed order */
|
||||||
|
EXPORT uint32_t revbits(uint32_t nbits, uint32_t value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||||
|
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||||
|
EXPORT double fracrevbits(uint32_t i);
|
||||||
|
|
||||||
|
/** @brief smallest multiple of 64 higher or equal to size */
|
||||||
|
EXPORT uint64_t ceilto64b(uint64_t size);
|
||||||
|
|
||||||
|
/** @brief smallest multiple of 32 higher or equal to size */
|
||||||
|
EXPORT uint64_t ceilto32b(uint64_t size);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_COMMONS_PRIVATE_H
|
||||||
22
spqlios/lib/spqlios/cplx/README.md
Normal file
22
spqlios/lib/spqlios/cplx/README.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`.
|
||||||
|
One complex is represented by two consecutive doubles `(real,imag)`
|
||||||
|
Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1
|
||||||
|
corresponds to the complex polynomial of half degree `M=N/2`:
|
||||||
|
`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i`
|
||||||
|
|
||||||
|
For a complex polynomial A(X) sum c_i X^i of degree M-1
|
||||||
|
or a real polynomial sum a_i X^i of degree N
|
||||||
|
|
||||||
|
coefficient space:
|
||||||
|
a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1}
|
||||||
|
or equivalently
|
||||||
|
Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1})
|
||||||
|
|
||||||
|
eval space:
|
||||||
|
c(omega_{0}),...,c(omega_{M-1})
|
||||||
|
|
||||||
|
where
|
||||||
|
omega_j = omega^{1+rev_{2N}(j)}
|
||||||
|
and omega = exp(i.pi/N)
|
||||||
|
|
||||||
|
rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order.
|
||||||
80
spqlios/lib/spqlios/cplx/cplx_common.c
Normal file
80
spqlios/lib/spqlios/cplx/cplx_common.c
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
|
||||||
|
void cplx_set(CPLX r, const CPLX a) {
|
||||||
|
r[0] = a[0];
|
||||||
|
r[1] = a[1];
|
||||||
|
}
|
||||||
|
void cplx_neg(CPLX r, const CPLX a) {
|
||||||
|
r[0] = -a[0];
|
||||||
|
r[1] = -a[1];
|
||||||
|
}
|
||||||
|
void cplx_add(CPLX r, const CPLX a, const CPLX b) {
|
||||||
|
r[0] = a[0] + b[0];
|
||||||
|
r[1] = a[1] + b[1];
|
||||||
|
}
|
||||||
|
void cplx_sub(CPLX r, const CPLX a, const CPLX b) {
|
||||||
|
r[0] = a[0] - b[0];
|
||||||
|
r[1] = a[1] - b[1];
|
||||||
|
}
|
||||||
|
void cplx_mul(CPLX r, const CPLX a, const CPLX b) {
|
||||||
|
double re = a[0] * b[0] - a[1] * b[1];
|
||||||
|
r[1] = a[0] * b[1] + a[1] * b[0];
|
||||||
|
r[0] = re;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
|
||||||
|
* Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||||
|
* Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||||
|
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||||
|
* @param h number of "coefficients" h >= 1
|
||||||
|
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||||
|
* @param powom y represented as (yre,yim)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
|
||||||
|
CPLX* d0 = data;
|
||||||
|
CPLX* d1 = data + h;
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
CPLX diff;
|
||||||
|
cplx_sub(diff, d0[i], d1[i]);
|
||||||
|
cplx_add(d0[i], d0[i], d1[i]);
|
||||||
|
cplx_mul(d1[i], diff, powom);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Do two layers of itwiddle (i.e. split).
|
||||||
|
* Input/output: d0,d1,d2,d3 of length h
|
||||||
|
* Algo:
|
||||||
|
* itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0])
|
||||||
|
* itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1])
|
||||||
|
* @param h number of "coefficients" h >= 1
|
||||||
|
* @param data 4h complex coefficients interleaved and 256b aligned
|
||||||
|
* @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
|
||||||
|
CPLX* d0 = data;
|
||||||
|
CPLX* d2 = data + 2*h;
|
||||||
|
const CPLX* om0 = powom;
|
||||||
|
CPLX iom0;
|
||||||
|
iom0[0]=powom[0][1];
|
||||||
|
iom0[1]=-powom[0][0];
|
||||||
|
const CPLX* om1 = powom+1;
|
||||||
|
cplx_split_fft_ref(h, d0, *om0);
|
||||||
|
cplx_split_fft_ref(h, d2, iom0);
|
||||||
|
cplx_split_fft_ref(2*h, d0, *om1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input: Q(y),Q(-y)
|
||||||
|
* Output: P_0(z),P_1(z)
|
||||||
|
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||||
|
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||||
|
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||||
|
*/
|
||||||
|
void split_fft_last_ref(CPLX* data, const CPLX powom) {
|
||||||
|
CPLX diff;
|
||||||
|
cplx_sub(diff, data[0], data[1]);
|
||||||
|
cplx_add(data[0], data[0], data[1]);
|
||||||
|
cplx_mul(data[1], diff, powom);
|
||||||
|
}
|
||||||
158
spqlios/lib/spqlios/cplx/cplx_conversions.c
Normal file
158
spqlios/lib/spqlios/cplx/cplx_conversions.c
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
#include <errno.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const int32_t* inre = x;
|
||||||
|
const int32_t* inim = x + m;
|
||||||
|
CPLX* out = r;
|
||||||
|
for (uint32_t i = 0; i < m; ++i) {
|
||||||
|
out[i][0] = (double)inre[i];
|
||||||
|
out[i][1] = (double)inim[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||||
|
static const double _2p32 = 1. / (INT64_C(1) << 32);
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const int32_t* inre = x;
|
||||||
|
const int32_t* inim = x + m;
|
||||||
|
CPLX* out = r;
|
||||||
|
for (uint32_t i = 0; i < m; ++i) {
|
||||||
|
out[i][0] = ((double)inre[i]) * _2p32;
|
||||||
|
out[i][1] = ((double)inim[i]) * _2p32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||||
|
static const double _2p32 = (INT64_C(1) << 32);
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
double factor = _2p32 / precomp->divisor;
|
||||||
|
int32_t* outre = r;
|
||||||
|
int32_t* outim = r + m;
|
||||||
|
const CPLX* in = x;
|
||||||
|
// Note: this formula will only work if abs(in) < 2^32
|
||||||
|
for (uint32_t i = 0; i < m; ++i) {
|
||||||
|
outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor));
|
||||||
|
outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) {
|
||||||
|
res->m = m;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
if (m >= 8) {
|
||||||
|
res->function = cplx_from_znx32_avx2_fma;
|
||||||
|
} else {
|
||||||
|
res->function = cplx_from_znx32_ref;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res->function = cplx_from_znx32_ref;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) {
|
||||||
|
CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP));
|
||||||
|
if (!res) return spqlios_error(strerror(errno));
|
||||||
|
return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m));
|
||||||
|
}
|
||||||
|
|
||||||
|
void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) {
|
||||||
|
res->m = m;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
if (m >= 8) {
|
||||||
|
res->function = cplx_from_tnx32_avx2_fma;
|
||||||
|
} else {
|
||||||
|
res->function = cplx_from_tnx32_ref;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res->function = cplx_from_tnx32_ref;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) {
|
||||||
|
CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP));
|
||||||
|
if (!res) return spqlios_error(strerror(errno));
|
||||||
|
return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m));
|
||||||
|
}
|
||||||
|
|
||||||
|
void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) {
|
||||||
|
if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2");
|
||||||
|
if (m & (m - 1)) return spqlios_error("m must be a power of 2");
|
||||||
|
if (log2overhead > 52) return spqlios_error("log2overhead is too large");
|
||||||
|
res->m = m;
|
||||||
|
res->divisor = divisor;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
if (log2overhead <= 18) {
|
||||||
|
if (m >= 8) {
|
||||||
|
res->function = cplx_to_tnx32_avx2_fma;
|
||||||
|
} else {
|
||||||
|
res->function = cplx_to_tnx32_ref;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res->function = cplx_to_tnx32_ref;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res->function = cplx_to_tnx32_ref;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) {
|
||||||
|
CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP));
|
||||||
|
if (!res) return spqlios_error(strerror(errno));
|
||||||
|
return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the znx32 to cplx conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||||
|
// not checking for log2bound which is not relevant here
|
||||||
|
static CPLX_FROM_ZNX32_PRECOMP precomp[32];
|
||||||
|
CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m);
|
||||||
|
if (!p->function) {
|
||||||
|
if (!init_cplx_from_znx32_precomp(p, m)) abort();
|
||||||
|
}
|
||||||
|
p->function(p, r, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||||
|
static CPLX_FROM_TNX32_PRECOMP precomp[32];
|
||||||
|
CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||||
|
if (!p->function) {
|
||||||
|
if (!init_cplx_from_tnx32_precomp(p, m)) abort();
|
||||||
|
}
|
||||||
|
p->function(p, r, x);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) {
|
||||||
|
struct LAST_CPLX_TO_TNX32_PRECOMP {
|
||||||
|
CPLX_TO_TNX32_PRECOMP p;
|
||||||
|
double last_divisor;
|
||||||
|
double last_log2over;
|
||||||
|
};
|
||||||
|
static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32];
|
||||||
|
struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||||
|
if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) {
|
||||||
|
memset(p, 0, sizeof(*p));
|
||||||
|
if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort();
|
||||||
|
p->last_divisor = divisor;
|
||||||
|
p->last_log2over = log2overhead;
|
||||||
|
}
|
||||||
|
p->p.function(&p->p, r, x);
|
||||||
|
}
|
||||||
104
spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c
Normal file
104
spqlios/lib/spqlios/cplx/cplx_conversions_avx2_fma.c
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef int32_t I8MEM[8];
|
||||||
|
typedef double D4MEM[4];
|
||||||
|
|
||||||
|
__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) {
|
||||||
|
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||||
|
const I8MEM* inre = (I8MEM*)(x);
|
||||||
|
const I8MEM* inim = (I8MEM*)(x+m);
|
||||||
|
D4MEM* out = (D4MEM*) r;
|
||||||
|
const uint64_t ms8 = m/8;
|
||||||
|
for (uint32_t i=0; i<ms8; ++i) {
|
||||||
|
__m256i rea = _mm256_loadu_si256((__m256i*) inre[0]);
|
||||||
|
__m256i ima = _mm256_loadu_si256((__m256i*) inim[0]);
|
||||||
|
rea = _mm256_add_epi32(rea, S);
|
||||||
|
ima = _mm256_add_epi32(ima, S);
|
||||||
|
__m256i tmpa = _mm256_unpacklo_epi32(rea, ima);
|
||||||
|
__m256i tmpc = _mm256_unpackhi_epi32(rea, ima);
|
||||||
|
__m256i cpla = _mm256_permute2x128_si256(tmpa,tmpc,0x20);
|
||||||
|
__m256i cplc = _mm256_permute2x128_si256(tmpa,tmpc,0x31);
|
||||||
|
tmpa = _mm256_unpacklo_epi32(cpla, C);
|
||||||
|
__m256i tmpb = _mm256_unpackhi_epi32(cpla, C);
|
||||||
|
tmpc = _mm256_unpacklo_epi32(cplc, C);
|
||||||
|
__m256i tmpd = _mm256_unpackhi_epi32(cplc, C);
|
||||||
|
cpla = _mm256_permute2x128_si256(tmpa,tmpb,0x20);
|
||||||
|
__m256i cplb = _mm256_permute2x128_si256(tmpa,tmpb,0x31);
|
||||||
|
cplc = _mm256_permute2x128_si256(tmpc,tmpd,0x20);
|
||||||
|
__m256i cpld = _mm256_permute2x128_si256(tmpc,tmpd,0x31);
|
||||||
|
__m256d dcpla = _mm256_sub_pd(_mm256_castsi256_pd(cpla), R);
|
||||||
|
__m256d dcplb = _mm256_sub_pd(_mm256_castsi256_pd(cplb), R);
|
||||||
|
__m256d dcplc = _mm256_sub_pd(_mm256_castsi256_pd(cplc), R);
|
||||||
|
__m256d dcpld = _mm256_sub_pd(_mm256_castsi256_pd(cpld), R);
|
||||||
|
_mm256_storeu_pd(out[0], dcpla);
|
||||||
|
_mm256_storeu_pd(out[1], dcplb);
|
||||||
|
_mm256_storeu_pd(out[2], dcplc);
|
||||||
|
_mm256_storeu_pd(out[3], dcpld);
|
||||||
|
inre += 1;
|
||||||
|
inim += 1;
|
||||||
|
out += 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||||
|
//note: the hex code of 2^31 + 2^52 is 0x4330000080000000
|
||||||
|
const __m256i C = _mm256_set1_epi32(0x43300000);
|
||||||
|
const __m256d R = _mm256_set1_pd((INT64_C(1) << 31) + (INT64_C(1) << 52));
|
||||||
|
// double XX = INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52);
|
||||||
|
//printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||||
|
//abort();
|
||||||
|
const uint64_t m = precomp->m;
|
||||||
|
cplx_from_any_fma(m, r, x, C, R);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||||
|
//note: the hex code of 2^-1 + 2^30 is 0x4130000080000000
|
||||||
|
const __m256i C = _mm256_set1_epi32(0x41300000);
|
||||||
|
const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20));
|
||||||
|
// double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32);
|
||||||
|
//printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||||
|
//abort();
|
||||||
|
const uint64_t m = precomp->m;
|
||||||
|
cplx_from_any_fma(m, r, x, C, R);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||||
|
const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor);
|
||||||
|
const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||||
|
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||||
|
//const __m256i IDX = _mm256_set_epi32(0,4,1,5,2,6,3,7);
|
||||||
|
const __m256i IDX = _mm256_set_epi32(7,3,6,2,5,1,4,0);
|
||||||
|
const uint64_t m = precomp->m;
|
||||||
|
const uint64_t ms8 = m/8;
|
||||||
|
I8MEM* outre = (I8MEM*) r;
|
||||||
|
I8MEM* outim = (I8MEM*) (r+m);
|
||||||
|
const D4MEM* in = x;
|
||||||
|
// Note: this formula will only work if abs(in) < 2^32
|
||||||
|
for (uint32_t i=0; i<ms8; ++i) {
|
||||||
|
__m256d cpla = _mm256_loadu_pd(in[0]);
|
||||||
|
__m256d cplb = _mm256_loadu_pd(in[1]);
|
||||||
|
__m256d cplc = _mm256_loadu_pd(in[2]);
|
||||||
|
__m256d cpld = _mm256_loadu_pd(in[3]);
|
||||||
|
__m256i icpla = _mm256_castpd_si256(_mm256_add_pd(cpla, R));
|
||||||
|
__m256i icplb = _mm256_castpd_si256(_mm256_add_pd(cplb, R));
|
||||||
|
__m256i icplc = _mm256_castpd_si256(_mm256_add_pd(cplc, R));
|
||||||
|
__m256i icpld = _mm256_castpd_si256(_mm256_add_pd(cpld, R));
|
||||||
|
icpla = _mm256_or_si256(_mm256_and_si256(icpla, MASK), _mm256_slli_epi64(icplb, 32));
|
||||||
|
icplc = _mm256_or_si256(_mm256_and_si256(icplc, MASK), _mm256_slli_epi64(icpld, 32));
|
||||||
|
icpla = _mm256_xor_si256(icpla, S);
|
||||||
|
icplc = _mm256_xor_si256(icplc, S);
|
||||||
|
__m256i re = _mm256_unpacklo_epi64(icpla, icplc);
|
||||||
|
__m256i im = _mm256_unpackhi_epi64(icpla, icplc);
|
||||||
|
re = _mm256_permutevar8x32_epi32(re, IDX);
|
||||||
|
im = _mm256_permutevar8x32_epi32(im, IDX);
|
||||||
|
_mm256_storeu_si256((__m256i*)outre[0], re);
|
||||||
|
_mm256_storeu_si256((__m256i*)outim[0], im);
|
||||||
|
outre += 1;
|
||||||
|
outim += 1;
|
||||||
|
in += 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
18
spqlios/lib/spqlios/cplx/cplx_execute.c
Normal file
18
spqlios/lib/spqlios/cplx/cplx_execute.c
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||||
|
tables->function(tables, r, a);
|
||||||
|
}
|
||||||
|
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||||
|
tables->function(tables, r, a);
|
||||||
|
}
|
||||||
|
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
|
||||||
|
tables->function(tables, r, a);
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||||
|
tables->function(tables, r, a, b);
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||||
|
tables->function(tables, r, a, b);
|
||||||
|
}
|
||||||
41
spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c
Normal file
41
spqlios/lib/spqlios/cplx/cplx_fallbacks_aarch64.c
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||||
|
UNDEFINED(); // not defined for non x86 targets
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||||
|
UNDEFINED();
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
UNDEFINED();
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
|
||||||
|
const void* b) {
|
||||||
|
UNDEFINED();
|
||||||
|
}
|
||||||
|
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||||
|
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||||
|
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||||
|
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||||
|
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); }
|
||||||
|
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT
|
||||||
|
void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT
|
||||||
|
void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){
|
||||||
|
UNDEFINED()} EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b,
|
||||||
|
const void* om){UNDEFINED()} EXPORT
|
||||||
|
void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||||
|
const void* om){UNDEFINED()} EXPORT
|
||||||
|
void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||||
|
const void* om){UNDEFINED()}
|
||||||
|
|
||||||
|
// DEPRECATED?
|
||||||
|
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||||
|
void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||||
|
void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a){UNDEFINED()}
|
||||||
|
|
||||||
|
// executors
|
||||||
|
//EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
|
||||||
|
// itables->function(itables, data);
|
||||||
|
//}
|
||||||
|
//EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
|
||||||
221
spqlios/lib/spqlios/cplx/cplx_fft.h
Normal file
221
spqlios/lib/spqlios/cplx/cplx_fft.h
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
#ifndef SPQLIOS_CPLX_FFT_H
|
||||||
|
#define SPQLIOS_CPLX_FFT_H
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP;
|
||||||
|
typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP;
|
||||||
|
typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP;
|
||||||
|
typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP;
|
||||||
|
typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP;
|
||||||
|
typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP;
|
||||||
|
typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP;
|
||||||
|
typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP;
|
||||||
|
typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP;
|
||||||
|
typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP;
|
||||||
|
typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief precomputes fft tables.
|
||||||
|
* The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn.
|
||||||
|
* The resulting pointer is to be passed as "tables" argument to any call to the fft function.
|
||||||
|
* The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to
|
||||||
|
* the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure
|
||||||
|
* that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft
|
||||||
|
* table must be deleted by delete_fft_precomp after its last usage.
|
||||||
|
*/
|
||||||
|
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief gets the address of a fft buffer allocated during new_fft_precomp.
|
||||||
|
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||||
|
* and does not need to be released afterwards.
|
||||||
|
*/
|
||||||
|
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief allocates a new fft buffer.
|
||||||
|
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||||
|
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||||
|
*/
|
||||||
|
EXPORT void* new_cplx_fft_buffer(uint32_t m);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief allocates a new fft buffer.
|
||||||
|
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||||
|
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||||
|
*/
|
||||||
|
EXPORT void delete_cplx_fft_buffer(void* buffer);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief deallocates a fft table and all its built-in buffers.
|
||||||
|
*/
|
||||||
|
#define delete_cplx_fft_precomp free
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief computes a direct fft in-place over data.
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||||
|
|
||||||
|
EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers);
|
||||||
|
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index);
|
||||||
|
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data);
|
||||||
|
#define delete_cplx_ifft_precomp free
|
||||||
|
|
||||||
|
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m);
|
||||||
|
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||||
|
#define delete_cplx_fftvec_mul_precomp free
|
||||||
|
|
||||||
|
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m);
|
||||||
|
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||||
|
#define delete_cplx_fftvec_addmul_precomp free
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief prepares a conversion from ZnX to the cplx layout.
|
||||||
|
* All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use
|
||||||
|
* this function on a larger coefficient is undefined behaviour. The resulting precomputed data must
|
||||||
|
* be freed with `new_cplx_from_znx32_precomp`
|
||||||
|
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||||
|
* int32 coefficients in natural order modulo X^n+1
|
||||||
|
* @param log2bound bound on the input coefficients. Must be between 0 and 32
|
||||||
|
*/
|
||||||
|
EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m);
|
||||||
|
/**
|
||||||
|
* @brief converts from ZnX to the cplx layout.
|
||||||
|
* @param tables precomputed data obtained by new_cplx_from_znx32_precomp.
|
||||||
|
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||||
|
* @param x input array of n bounded integer coefficients mod X^n+1
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||||
|
/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */
|
||||||
|
#define delete_cplx_from_znx32_precomp free
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief prepares a conversion from TnX to the cplx layout.
|
||||||
|
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||||
|
* torus32 coefficients. The resulting precomputed data must
|
||||||
|
* be freed with `delete_cplx_from_tnx32_precomp`
|
||||||
|
*/
|
||||||
|
EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m);
|
||||||
|
/**
|
||||||
|
* @brief converts from TnX to the cplx layout.
|
||||||
|
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||||
|
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||||
|
* @param x input array of n torus32 coefficients mod X^n+1
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||||
|
/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */
|
||||||
|
#define delete_cplx_from_tnx32_precomp free
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief prepares a rescale and conversion from the cplx layout to TnX.
|
||||||
|
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
|
||||||
|
* torus32 coefficients.
|
||||||
|
* @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1.
|
||||||
|
* Remember that the output of an iFFT must be divided by m.
|
||||||
|
* @param log2overhead all inputs absolute values must be within divisor.2^log2overhead.
|
||||||
|
* For any inputs outside of these bounds, the conversion is undefined behaviour.
|
||||||
|
* The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18.
|
||||||
|
*/
|
||||||
|
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead);
|
||||||
|
/**
|
||||||
|
* @brief rescale, converts and reduce mod 1 from cplx layout to torus32.
|
||||||
|
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||||
|
* @param r resulting array of n torus32 coefficients mod X^n+1
|
||||||
|
* @param x input array of m cplx coefficients mod X^m-i
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a);
|
||||||
|
#define delete_cplx_to_tnx32_precomp free
|
||||||
|
|
||||||
|
EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor);
|
||||||
|
EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||||
|
#define delete_cplx_to_znx32_simple free
|
||||||
|
|
||||||
|
EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m);
|
||||||
|
EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||||
|
#define delete_cplx_from_rnx64_simple free
|
||||||
|
|
||||||
|
EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor);
|
||||||
|
EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
#define delete_cplx_round_to_rnx64_simple free
|
||||||
|
|
||||||
|
EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound);
|
||||||
|
EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
#define delete_cplx_round_to_rnx64_simple free
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the fft function.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically.
|
||||||
|
* It is advised to do one dry-run per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_fft_simple(uint32_t m, void* data);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the ifft function.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread
|
||||||
|
* environment */
|
||||||
|
EXPORT void cplx_ifft_simple(uint32_t m, void* data);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the fftvec multiplication function.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the fftvec addmul function.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the znx32 to cplx conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||||
|
/**
|
||||||
|
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||||
|
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||||
|
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||||
|
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief converts, divides and round from cplx to znx32 (simple API)
|
||||||
|
* @param m the complex dimension
|
||||||
|
* @param divisor the divisor: a power of two, often m after an ifft
|
||||||
|
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||||
|
* @param x the input: must hold m complex numbers.
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief converts from rnx64 to cplx (simple API)
|
||||||
|
* The bound on the output is assumed to be within ]2^-31,2^31[.
|
||||||
|
* Any coefficient that would fall outside this range is undefined behaviour.
|
||||||
|
* @param m the complex dimension
|
||||||
|
* @param r the result: must be an array of m complex numbers. r must be distinct from x
|
||||||
|
* @param x the input: must be an array of 2m doubles.
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief converts, divides from cplx to rnx64 (simple API)
|
||||||
|
* @param m the complex dimension
|
||||||
|
* @param divisor the divisor: a power of two, often m after an ifft
|
||||||
|
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||||
|
* @param x the input: must hold m complex numbers.
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief converts, divides and round to integer from cplx to rnx32 (simple API)
|
||||||
|
* @param m the complex dimension
|
||||||
|
* @param divisor the divisor: a power of two, often m after an ifft
|
||||||
|
* @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm.
|
||||||
|
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||||
|
* @param x the input: must hold m complex numbers.
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_CPLX_FFT_H
|
||||||
156
spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s
Normal file
156
spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma.s
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# shifted FFT over X^16-i
|
||||||
|
# 1st argument (rdi) contains 16 complexes
|
||||||
|
# 2nd argument (rsi) contains: 8 complexes
|
||||||
|
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||||
|
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
# j = sqrt(i), k=sqrt(j)
|
||||||
|
.globl cplx_fft16_avx_fma
|
||||||
|
cplx_fft16_avx_fma:
|
||||||
|
vmovupd (%rdi),%ymm8
|
||||||
|
vmovupd 0x20(%rdi),%ymm9
|
||||||
|
vmovupd 0x40(%rdi),%ymm10
|
||||||
|
vmovupd 0x60(%rdi),%ymm11
|
||||||
|
vmovupd 0x80(%rdi),%ymm12
|
||||||
|
vmovupd 0xa0(%rdi),%ymm13
|
||||||
|
vmovupd 0xc0(%rdi),%ymm14
|
||||||
|
vmovupd 0xe0(%rdi),%ymm15
|
||||||
|
|
||||||
|
.first_pass:
|
||||||
|
vmovupd (%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm1,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm1,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm1,%ymm7
|
||||||
|
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||||
|
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||||
|
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||||
|
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm12
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm13
|
||||||
|
vsubpd %ymm6,%ymm10,%ymm14
|
||||||
|
vsubpd %ymm7,%ymm11,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm6,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm7,%ymm11,%ymm11
|
||||||
|
|
||||||
|
.second_pass:
|
||||||
|
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||||
|
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm1,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm0,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm0,%ymm7
|
||||||
|
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||||
|
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||||
|
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm10
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm11
|
||||||
|
vaddpd %ymm6,%ymm12,%ymm14
|
||||||
|
vaddpd %ymm7,%ymm13,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vsubpd %ymm6,%ymm12,%ymm12
|
||||||
|
vsubpd %ymm7,%ymm13,%ymm13
|
||||||
|
|
||||||
|
.third_pass:
|
||||||
|
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||||
|
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||||
|
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||||
|
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm0,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm3,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm2,%ymm7
|
||||||
|
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||||
|
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||||
|
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm9
|
||||||
|
vaddpd %ymm5,%ymm10,%ymm11
|
||||||
|
vsubpd %ymm6,%ymm12,%ymm13
|
||||||
|
vaddpd %ymm7,%ymm14,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vsubpd %ymm5,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm6,%ymm12,%ymm12
|
||||||
|
vsubpd %ymm7,%ymm14,%ymm14
|
||||||
|
|
||||||
|
.fourth_pass:
|
||||||
|
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||||
|
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm12,%ymm1,%ymm12
|
||||||
|
vmulpd %ymm13,%ymm0,%ymm13
|
||||||
|
vmulpd %ymm14,%ymm3,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm2,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||||
|
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||||
|
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||||
|
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||||
|
vsubpd %ymm12,%ymm8,%ymm4
|
||||||
|
vaddpd %ymm13,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm10,%ymm6
|
||||||
|
vaddpd %ymm15,%ymm11,%ymm7
|
||||||
|
vaddpd %ymm12,%ymm8,%ymm8
|
||||||
|
vsubpd %ymm13,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm10,%ymm10
|
||||||
|
vsubpd %ymm15,%ymm11,%ymm11
|
||||||
|
|
||||||
|
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
|
||||||
|
.save_and_return:
|
||||||
|
vmovupd %ymm8,(%rdi)
|
||||||
|
vmovupd %ymm9,0x20(%rdi)
|
||||||
|
vmovupd %ymm10,0x40(%rdi)
|
||||||
|
vmovupd %ymm11,0x60(%rdi)
|
||||||
|
vmovupd %ymm12,0x80(%rdi)
|
||||||
|
vmovupd %ymm13,0xa0(%rdi)
|
||||||
|
vmovupd %ymm14,0xc0(%rdi)
|
||||||
|
vmovupd %ymm15,0xe0(%rdi)
|
||||||
|
ret
|
||||||
|
.size cplx_fft16_avx_fma, .-cplx_fft16_avx_fma
|
||||||
|
.section .note.GNU-stack,"",@progbits
|
||||||
190
spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s
Normal file
190
spqlios/lib/spqlios/cplx/cplx_fft16_avx_fma_win32.s
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
.text
|
||||||
|
.p2align 4
|
||||||
|
.globl cplx_fft16_avx_fma
|
||||||
|
.def cplx_fft16_avx_fma; .scl 2; .type 32; .endef
|
||||||
|
cplx_fft16_avx_fma:
|
||||||
|
|
||||||
|
pushq %rdi
|
||||||
|
pushq %rsi
|
||||||
|
movq %rcx,%rdi
|
||||||
|
movq %rdx,%rsi
|
||||||
|
subq $0x100,%rsp
|
||||||
|
movdqu %xmm6,(%rsp)
|
||||||
|
movdqu %xmm7,0x10(%rsp)
|
||||||
|
movdqu %xmm8,0x20(%rsp)
|
||||||
|
movdqu %xmm9,0x30(%rsp)
|
||||||
|
movdqu %xmm10,0x40(%rsp)
|
||||||
|
movdqu %xmm11,0x50(%rsp)
|
||||||
|
movdqu %xmm12,0x60(%rsp)
|
||||||
|
movdqu %xmm13,0x70(%rsp)
|
||||||
|
movdqu %xmm14,0x80(%rsp)
|
||||||
|
movdqu %xmm15,0x90(%rsp)
|
||||||
|
callq cplx_fft16_avx_fma_amd64
|
||||||
|
movdqu (%rsp),%xmm6
|
||||||
|
movdqu 0x10(%rsp),%xmm7
|
||||||
|
movdqu 0x20(%rsp),%xmm8
|
||||||
|
movdqu 0x30(%rsp),%xmm9
|
||||||
|
movdqu 0x40(%rsp),%xmm10
|
||||||
|
movdqu 0x50(%rsp),%xmm11
|
||||||
|
movdqu 0x60(%rsp),%xmm12
|
||||||
|
movdqu 0x70(%rsp),%xmm13
|
||||||
|
movdqu 0x80(%rsp),%xmm14
|
||||||
|
movdqu 0x90(%rsp),%xmm15
|
||||||
|
addq $0x100,%rsp
|
||||||
|
popq %rsi
|
||||||
|
popq %rdi
|
||||||
|
retq
|
||||||
|
|
||||||
|
# shifted FFT over X^16-i
|
||||||
|
# 1st argument (rdi) contains 16 complexes
|
||||||
|
# 2nd argument (rsi) contains: 8 complexes
|
||||||
|
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||||
|
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
# j = sqrt(i), k=sqrt(j)
|
||||||
|
cplx_fft16_avx_fma_amd64:
|
||||||
|
vmovupd (%rdi),%ymm8
|
||||||
|
vmovupd 0x20(%rdi),%ymm9
|
||||||
|
vmovupd 0x40(%rdi),%ymm10
|
||||||
|
vmovupd 0x60(%rdi),%ymm11
|
||||||
|
vmovupd 0x80(%rdi),%ymm12
|
||||||
|
vmovupd 0xa0(%rdi),%ymm13
|
||||||
|
vmovupd 0xc0(%rdi),%ymm14
|
||||||
|
vmovupd 0xe0(%rdi),%ymm15
|
||||||
|
|
||||||
|
.first_pass:
|
||||||
|
vmovupd (%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm1,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm1,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm1,%ymm7
|
||||||
|
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||||
|
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||||
|
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||||
|
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm12
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm13
|
||||||
|
vsubpd %ymm6,%ymm10,%ymm14
|
||||||
|
vsubpd %ymm7,%ymm11,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm6,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm7,%ymm11,%ymm11
|
||||||
|
|
||||||
|
.second_pass:
|
||||||
|
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||||
|
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm1,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm0,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm0,%ymm7
|
||||||
|
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||||
|
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||||
|
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm10
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm11
|
||||||
|
vaddpd %ymm6,%ymm12,%ymm14
|
||||||
|
vaddpd %ymm7,%ymm13,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vsubpd %ymm6,%ymm12,%ymm12
|
||||||
|
vsubpd %ymm7,%ymm13,%ymm13
|
||||||
|
|
||||||
|
.third_pass:
|
||||||
|
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||||
|
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||||
|
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||||
|
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm0,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm3,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm2,%ymm7
|
||||||
|
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||||
|
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||||
|
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm9
|
||||||
|
vaddpd %ymm5,%ymm10,%ymm11
|
||||||
|
vsubpd %ymm6,%ymm12,%ymm13
|
||||||
|
vaddpd %ymm7,%ymm14,%ymm15
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vsubpd %ymm5,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm6,%ymm12,%ymm12
|
||||||
|
vsubpd %ymm7,%ymm14,%ymm14
|
||||||
|
|
||||||
|
.fourth_pass:
|
||||||
|
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||||
|
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm12,%ymm1,%ymm12
|
||||||
|
vmulpd %ymm13,%ymm0,%ymm13
|
||||||
|
vmulpd %ymm14,%ymm3,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm2,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||||
|
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||||
|
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||||
|
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||||
|
vsubpd %ymm12,%ymm8,%ymm4
|
||||||
|
vaddpd %ymm13,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm10,%ymm6
|
||||||
|
vaddpd %ymm15,%ymm11,%ymm7
|
||||||
|
vaddpd %ymm12,%ymm8,%ymm8
|
||||||
|
vsubpd %ymm13,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm10,%ymm10
|
||||||
|
vsubpd %ymm15,%ymm11,%ymm11
|
||||||
|
|
||||||
|
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
|
||||||
|
.save_and_return:
|
||||||
|
vmovupd %ymm8,(%rdi)
|
||||||
|
vmovupd %ymm9,0x20(%rdi)
|
||||||
|
vmovupd %ymm10,0x40(%rdi)
|
||||||
|
vmovupd %ymm11,0x60(%rdi)
|
||||||
|
vmovupd %ymm12,0x80(%rdi)
|
||||||
|
vmovupd %ymm13,0xa0(%rdi)
|
||||||
|
vmovupd %ymm14,0xc0(%rdi)
|
||||||
|
vmovupd %ymm15,0xe0(%rdi)
|
||||||
|
ret
|
||||||
8
spqlios/lib/spqlios/cplx/cplx_fft_asserts.c
Normal file
8
spqlios/lib/spqlios/cplx/cplx_fft_asserts.c
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#include "cplx_fft_private.h"
|
||||||
|
#include "../commons_private.h"
|
||||||
|
|
||||||
|
__always_inline void my_asserts() {
|
||||||
|
STATIC_ASSERT(sizeof(FFT_FUNCTION)==8);
|
||||||
|
STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP)==40);
|
||||||
|
STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP)==40);
|
||||||
|
}
|
||||||
266
spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c
Normal file
266
spqlios/lib/spqlios/cplx/cplx_fft_avx2_fma.c
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef double D4MEM[4];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex fft via bfs strategy (for m between 2 and 8)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_fft_avx2_fma_bfs_2(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||||
|
double* data = (double*)dat;
|
||||||
|
int32_t _2nblock = m >> 1; // = h in ref code
|
||||||
|
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||||
|
while (_2nblock >= 2) {
|
||||||
|
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||||
|
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, t2);
|
||||||
|
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
_mm256_storeu_pd(ddmid[0], newb);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
dd += nblock;
|
||||||
|
*omg += 1;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
_2nblock >>= 1;
|
||||||
|
}
|
||||||
|
// last iteration when _2nblock == 1
|
||||||
|
{
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||||
|
const __m256d ab = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d bb = _mm256_permute4x64_pd(ab, 0b11101110);
|
||||||
|
const __m256d bbbar = _mm256_permute4x64_pd(ab, 0b10111011);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(bbbar, omim);
|
||||||
|
const __m256d t2 = _mm256_fmaddsub_pd(bb, omre, t1);
|
||||||
|
const __m256d aa = _mm256_permute4x64_pd(ab, 0b01000100);
|
||||||
|
const __m256d newab = _mm256_add_pd(aa, t2);
|
||||||
|
_mm256_storeu_pd(dd[0], newab);
|
||||||
|
dd += 1;
|
||||||
|
*omg += 1;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__always_inline void cplx_twiddle_fft_avx2(int32_t h, D4MEM* data, const void* omg) {
|
||||||
|
const __m256d om = _mm256_loadu_pd(omg);
|
||||||
|
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
D4MEM* d0 = data;
|
||||||
|
D4MEM* const ddend = d0 + (h>>1);
|
||||||
|
D4MEM* d1 = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d b = _mm256_loadu_pd(d1[0]);
|
||||||
|
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(barb, omim);
|
||||||
|
const __m256d t2 = _mm256_fmaddsub_pd(b, omre, t1);
|
||||||
|
const __m256d a = _mm256_loadu_pd(d0[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, t2);
|
||||||
|
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||||
|
_mm256_storeu_pd(d0[0], newa);
|
||||||
|
_mm256_storeu_pd(d1[0], newb);
|
||||||
|
d0 += 1;
|
||||||
|
d1 += 1;
|
||||||
|
} while (d0 < ddend);
|
||||||
|
}
|
||||||
|
|
||||||
|
__always_inline void cplx_bitwiddle_fft_avx2(int32_t h, void* data, const void* powom) {
|
||||||
|
const __m256d omx = _mm256_loadu_pd(powom);
|
||||||
|
const __m256d oma = _mm256_permute2f128_pd(omx, omx, 0x00);
|
||||||
|
const __m256d omb = _mm256_permute2f128_pd(omx, omx, 0x11);
|
||||||
|
const __m256d omaim = _mm256_unpackhi_pd(oma, oma);
|
||||||
|
const __m256d omare = _mm256_unpacklo_pd(oma, oma);
|
||||||
|
const __m256d ombim = _mm256_unpackhi_pd(omb, omb);
|
||||||
|
const __m256d ombre = _mm256_unpacklo_pd(omb, omb);
|
||||||
|
D4MEM* d0 = (D4MEM*) data;
|
||||||
|
D4MEM* const ddend = d0 + (h>>1);
|
||||||
|
D4MEM* d1 = ddend;
|
||||||
|
D4MEM* d2 = d0+h;
|
||||||
|
D4MEM* d3 = d1+h;
|
||||||
|
__m256d reg0,reg1,reg2,reg3,tmp0,tmp1;
|
||||||
|
do {
|
||||||
|
reg0 = _mm256_loadu_pd(d0[0]);
|
||||||
|
reg1 = _mm256_loadu_pd(d1[0]);
|
||||||
|
reg2 = _mm256_loadu_pd(d2[0]);
|
||||||
|
reg3 = _mm256_loadu_pd(d3[0]);
|
||||||
|
tmp0 = _mm256_shuffle_pd(reg2, reg2, 5);
|
||||||
|
tmp1 = _mm256_shuffle_pd(reg3, reg3, 5);
|
||||||
|
tmp0 = _mm256_mul_pd(tmp0, omaim);
|
||||||
|
tmp1 = _mm256_mul_pd(tmp1, omaim);
|
||||||
|
tmp0 = _mm256_fmaddsub_pd(reg2, omare, tmp0);
|
||||||
|
tmp1 = _mm256_fmaddsub_pd(reg3, omare, tmp1);
|
||||||
|
reg2 = _mm256_sub_pd(reg0, tmp0);
|
||||||
|
reg3 = _mm256_sub_pd(reg1, tmp1);
|
||||||
|
reg0 = _mm256_add_pd(reg0, tmp0);
|
||||||
|
reg1 = _mm256_add_pd(reg1, tmp1);
|
||||||
|
//--------------------------------------
|
||||||
|
tmp0 = _mm256_shuffle_pd(reg1, reg1, 5);
|
||||||
|
tmp1 = _mm256_shuffle_pd(reg3, reg3, 5);
|
||||||
|
tmp0 = _mm256_mul_pd(tmp0, ombim); //(r,i)
|
||||||
|
tmp1 = _mm256_mul_pd(tmp1, ombre); //(-i,r)
|
||||||
|
tmp0 = _mm256_fmaddsub_pd(reg1, ombre, tmp0);
|
||||||
|
tmp1 = _mm256_fmsubadd_pd(reg3, ombim, tmp1);
|
||||||
|
reg1 = _mm256_sub_pd(reg0, tmp0);
|
||||||
|
reg3 = _mm256_add_pd(reg2, tmp1);
|
||||||
|
reg0 = _mm256_add_pd(reg0, tmp0);
|
||||||
|
reg2 = _mm256_sub_pd(reg2, tmp1);
|
||||||
|
/////
|
||||||
|
_mm256_storeu_pd(d0[0], reg0);
|
||||||
|
_mm256_storeu_pd(d1[0], reg1);
|
||||||
|
_mm256_storeu_pd(d2[0], reg2);
|
||||||
|
_mm256_storeu_pd(d3[0], reg3);
|
||||||
|
d0 += 1;
|
||||||
|
d1 += 1;
|
||||||
|
d2 += 1;
|
||||||
|
d3 += 1;
|
||||||
|
} while (d0 < ddend);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex fft via bfs strategy (for m >= 16)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_fft_avx2_fma_bfs_16(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||||
|
double* data = (double*)dat;
|
||||||
|
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||||
|
uint32_t mm = m;
|
||||||
|
uint32_t log2m = _mm_popcnt_u32(m-1); // log2(m)
|
||||||
|
if (log2m % 2 == 1) {
|
||||||
|
uint32_t h = mm>>1;
|
||||||
|
cplx_twiddle_fft_avx2(h, dat, **omg);
|
||||||
|
*omg += 1;
|
||||||
|
mm >>= 1;
|
||||||
|
}
|
||||||
|
while(mm>16) {
|
||||||
|
uint32_t h = mm/4;
|
||||||
|
for (CPLX* d = (CPLX*) data; d < (CPLX*) finaldd; d += mm) {
|
||||||
|
cplx_bitwiddle_fft_avx2(h, d, (CPLX*) *omg);
|
||||||
|
*omg += 1;
|
||||||
|
}
|
||||||
|
mm=h;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
cplx_fft16_avx_fma(dd, *omg);
|
||||||
|
dd += 8;
|
||||||
|
*omg += 4;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
_mm256_zeroupper();
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
int32_t _2nblock = m >> 1; // = h in ref code
|
||||||
|
while (_2nblock >= 16) {
|
||||||
|
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||||
|
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, t2);
|
||||||
|
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
_mm256_storeu_pd(ddmid[0], newb);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
dd += nblock;
|
||||||
|
*omg += 1;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
_2nblock >>= 1;
|
||||||
|
}
|
||||||
|
// last iteration when _2nblock == 8
|
||||||
|
{
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
cplx_fft16_avx_fma(dd, *omg);
|
||||||
|
dd += 8;
|
||||||
|
*omg += 4;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
_mm256_zeroupper();
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex fft via dfs recursion (for m >= 16)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_fft_avx2_fma_rec_16(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||||
|
if (m <= 8) return cplx_fft_avx2_fma_bfs_2(dat, omg, m);
|
||||||
|
if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(dat, omg, m);
|
||||||
|
double* data = (double*)dat;
|
||||||
|
int32_t _2nblock = m >> 1; // = h in ref code
|
||||||
|
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||||
|
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, t2);
|
||||||
|
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
_mm256_storeu_pd(ddmid[0], newb);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
*omg += 1;
|
||||||
|
cplx_fft_avx2_fma_rec_16(dat, omg, _2nblock);
|
||||||
|
cplx_fft_avx2_fma_rec_16(ddend, omg, _2nblock);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex fft via best strategy (for m>=1)
|
||||||
|
* @param dat the data to run the algorithm on: m complex numbers
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* precomp, void* d) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const D4MEM* omg = (D4MEM*)precomp->powomegas;
|
||||||
|
if (m <= 1) return;
|
||||||
|
if (m <= 8) return cplx_fft_avx2_fma_bfs_2(d, &omg, m);
|
||||||
|
if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(d, &omg, m);
|
||||||
|
cplx_fft_avx2_fma_rec_16(d, &omg, m);
|
||||||
|
}
|
||||||
453
spqlios/lib/spqlios/cplx/cplx_fft_avx512.c
Normal file
453
spqlios/lib/spqlios/cplx/cplx_fft_avx512.c
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef double D2MEM[2];
|
||||||
|
typedef double D4MEM[4];
|
||||||
|
typedef double D8MEM[8];
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
|
||||||
|
const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const D8MEM* aa = (D8MEM*)a;
|
||||||
|
const D8MEM* bb = (D8MEM*)b;
|
||||||
|
D8MEM* rr = (D8MEM*)r;
|
||||||
|
const D8MEM* const aend = aa + (m >> 2);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||||
|
const __m512d bri% = _mm512_loadu_pd(bb[%]);
|
||||||
|
const __m512d rri% = _mm512_loadu_pd(rr[%]);
|
||||||
|
const __m512d bir% = _mm512_shuffle_pd(bri%,bri%, 0b01010101);
|
||||||
|
const __m512d aii% = _mm512_shuffle_pd(ari%,ari%, 0b11111111);
|
||||||
|
const __m512d pro% = _mm512_fmaddsub_pd(aii%,bir%,rri%);
|
||||||
|
const __m512d arr% = _mm512_shuffle_pd(ari%,ari%, 0b00000000);
|
||||||
|
const __m512d res% = _mm512_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm512_storeu_pd(rr[%],res%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
bb += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 2
|
||||||
|
const __m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||||
|
const __m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||||
|
const __m512d bri0 = _mm512_loadu_pd(bb[0]);
|
||||||
|
const __m512d bri1 = _mm512_loadu_pd(bb[1]);
|
||||||
|
const __m512d rri0 = _mm512_loadu_pd(rr[0]);
|
||||||
|
const __m512d rri1 = _mm512_loadu_pd(rr[1]);
|
||||||
|
const __m512d bir0 = _mm512_shuffle_pd(bri0,bri0, 0b01010101);
|
||||||
|
const __m512d bir1 = _mm512_shuffle_pd(bri1,bri1, 0b01010101);
|
||||||
|
const __m512d aii0 = _mm512_shuffle_pd(ari0,ari0, 0b11111111);
|
||||||
|
const __m512d aii1 = _mm512_shuffle_pd(ari1,ari1, 0b11111111);
|
||||||
|
const __m512d pro0 = _mm512_fmaddsub_pd(aii0,bir0,rri0);
|
||||||
|
const __m512d pro1 = _mm512_fmaddsub_pd(aii1,bir1,rri1);
|
||||||
|
const __m512d arr0 = _mm512_shuffle_pd(ari0,ari0, 0b00000000);
|
||||||
|
const __m512d arr1 = _mm512_shuffle_pd(ari1,ari1, 0b00000000);
|
||||||
|
const __m512d res0 = _mm512_fmaddsub_pd(arr0,bri0,pro0);
|
||||||
|
const __m512d res1 = _mm512_fmaddsub_pd(arr1,bri1,pro1);
|
||||||
|
_mm512_storeu_pd(rr[0],res0);
|
||||||
|
_mm512_storeu_pd(rr[1],res1);
|
||||||
|
rr += 2; // ONCE
|
||||||
|
aa += 2; // ONCE
|
||||||
|
bb += 2; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
EXPORT void cplx_fftvec_mul_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||||
|
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||||
|
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||||
|
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||||
|
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b
|
||||||
|
const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a
|
||||||
|
const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a
|
||||||
|
const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a
|
||||||
|
const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a
|
||||||
|
const __m256d pro0 = _mm256_mul_pd(aii0,bir0);
|
||||||
|
const __m256d pro1 = _mm256_mul_pd(aii1,bir1);
|
||||||
|
const __m256d pro2 = _mm256_mul_pd(aii2,bir2);
|
||||||
|
const __m256d pro3 = _mm256_mul_pd(aii3,bir3);
|
||||||
|
const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a
|
||||||
|
const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a
|
||||||
|
const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a
|
||||||
|
const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a
|
||||||
|
const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0);
|
||||||
|
const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1);
|
||||||
|
const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2);
|
||||||
|
const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d res0 = _mm256_add_pd(ari0,bri0);
|
||||||
|
const __m256d res1 = _mm256_add_pd(ari1,bri1);
|
||||||
|
const __m256d res2 = _mm256_add_pd(ari2,bri2);
|
||||||
|
const __m256d res3 = _mm256_add_pd(ari3,bri3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||||
|
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||||
|
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d sum0 = _mm256_add_pd(ari0,bri0);
|
||||||
|
const __m256d sum1 = _mm256_add_pd(ari1,bri1);
|
||||||
|
const __m256d sum2 = _mm256_add_pd(ari2,bri2);
|
||||||
|
const __m256d sum3 = _mm256_add_pd(ari3,bri3);
|
||||||
|
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||||
|
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||||
|
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||||
|
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||||
|
const __m256d res0 = _mm256_sub_pd(rri0,sum0);
|
||||||
|
const __m256d res1 = _mm256_sub_pd(rri1,sum1);
|
||||||
|
const __m256d res2 = _mm256_sub_pd(rri2,sum2);
|
||||||
|
const __m256d res3 = _mm256_sub_pd(rri3,sum3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(rr[%],ari%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(rr[0],ari0);
|
||||||
|
_mm256_storeu_pd(rr[1],ari1);
|
||||||
|
_mm256_storeu_pd(rr[2],ari2);
|
||||||
|
_mm256_storeu_pd(rr[3],ari3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) {
|
||||||
|
double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*bb)[4] = (double(*)[4])b;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
const __m256d om = _mm256_loadu_pd(omg);
|
||||||
|
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||||
|
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||||
|
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||||
|
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5);
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5);
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5);
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5);
|
||||||
|
__m256d p0 = _mm256_mul_pd(bir0,omii);
|
||||||
|
__m256d p1 = _mm256_mul_pd(bir1,omii);
|
||||||
|
__m256d p2 = _mm256_mul_pd(bir2,omii);
|
||||||
|
__m256d p3 = _mm256_mul_pd(bir3,omii);
|
||||||
|
p0 = _mm256_fmaddsub_pd(bri0,omrr,p0);
|
||||||
|
p1 = _mm256_fmaddsub_pd(bri1,omrr,p1);
|
||||||
|
p2 = _mm256_fmaddsub_pd(bri2,omrr,p2);
|
||||||
|
p3 = _mm256_fmaddsub_pd(bri3,omrr,p3);
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0));
|
||||||
|
_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1));
|
||||||
|
_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2));
|
||||||
|
_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3));
|
||||||
|
_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0));
|
||||||
|
_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1));
|
||||||
|
_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2));
|
||||||
|
_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3));
|
||||||
|
// END_INTERLEAVE
|
||||||
|
bb += 4;
|
||||||
|
aa += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
D8MEM* aa = (D8MEM*)a;
|
||||||
|
D8MEM* bb = (D8MEM*)b;
|
||||||
|
D8MEM* const aend = aa + (m >> 2);
|
||||||
|
const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg));
|
||||||
|
const __m512d omrr = _mm512_shuffle_pd(om, om, 0b00000000);
|
||||||
|
const __m512d omii = _mm512_shuffle_pd(om, om, 0b11111111);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m512d bri% = _mm512_loadu_pd(bb[%]);
|
||||||
|
const __m512d bir% = _mm512_shuffle_pd(bri%,bri%,0b10011001);
|
||||||
|
__m512d p% = _mm512_mul_pd(bir%,omii);
|
||||||
|
p% = _mm512_fmaddsub_pd(bri%,omrr,p%);
|
||||||
|
const __m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||||
|
_mm512_storeu_pd(aa[%],_mm512_add_pd(ari%,p%));
|
||||||
|
_mm512_storeu_pd(bb[%],_mm512_sub_pd(ari%,p%));
|
||||||
|
bb += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m512d bri0 = _mm512_loadu_pd(bb[0]);
|
||||||
|
const __m512d bri1 = _mm512_loadu_pd(bb[1]);
|
||||||
|
const __m512d bri2 = _mm512_loadu_pd(bb[2]);
|
||||||
|
const __m512d bri3 = _mm512_loadu_pd(bb[3]);
|
||||||
|
const __m512d bir0 = _mm512_shuffle_pd(bri0,bri0,0b10011001);
|
||||||
|
const __m512d bir1 = _mm512_shuffle_pd(bri1,bri1,0b10011001);
|
||||||
|
const __m512d bir2 = _mm512_shuffle_pd(bri2,bri2,0b10011001);
|
||||||
|
const __m512d bir3 = _mm512_shuffle_pd(bri3,bri3,0b10011001);
|
||||||
|
__m512d p0 = _mm512_mul_pd(bir0,omii);
|
||||||
|
__m512d p1 = _mm512_mul_pd(bir1,omii);
|
||||||
|
__m512d p2 = _mm512_mul_pd(bir2,omii);
|
||||||
|
__m512d p3 = _mm512_mul_pd(bir3,omii);
|
||||||
|
p0 = _mm512_fmaddsub_pd(bri0,omrr,p0);
|
||||||
|
p1 = _mm512_fmaddsub_pd(bri1,omrr,p1);
|
||||||
|
p2 = _mm512_fmaddsub_pd(bri2,omrr,p2);
|
||||||
|
p3 = _mm512_fmaddsub_pd(bri3,omrr,p3);
|
||||||
|
const __m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||||
|
const __m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||||
|
const __m512d ari2 = _mm512_loadu_pd(aa[2]);
|
||||||
|
const __m512d ari3 = _mm512_loadu_pd(aa[3]);
|
||||||
|
_mm512_storeu_pd(aa[0],_mm512_add_pd(ari0,p0));
|
||||||
|
_mm512_storeu_pd(aa[1],_mm512_add_pd(ari1,p1));
|
||||||
|
_mm512_storeu_pd(aa[2],_mm512_add_pd(ari2,p2));
|
||||||
|
_mm512_storeu_pd(aa[3],_mm512_add_pd(ari3,p3));
|
||||||
|
_mm512_storeu_pd(bb[0], _mm512_sub_pd(ari0, p0));
|
||||||
|
_mm512_storeu_pd(bb[1], _mm512_sub_pd(ari1, p1));
|
||||||
|
_mm512_storeu_pd(bb[2], _mm512_sub_pd(ari2, p2));
|
||||||
|
_mm512_storeu_pd(bb[3], _mm512_sub_pd(ari3, p3));
|
||||||
|
bb += 4; // ONCE
|
||||||
|
aa += 4; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea,
|
||||||
|
const void* omg) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const uint64_t OFFSET = slicea / sizeof(D8MEM);
|
||||||
|
D8MEM* aa = (D8MEM*)a;
|
||||||
|
const D8MEM* aend = aa + (m >> 2);
|
||||||
|
const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg));
|
||||||
|
const __m512d om1rr = _mm512_shuffle_pd(om, om, 0);
|
||||||
|
const __m512d om1ii = _mm512_shuffle_pd(om, om, 15);
|
||||||
|
const __m512d om2rr = _mm512_shuffle_pd(om, om, 0);
|
||||||
|
const __m512d om2ii = _mm512_shuffle_pd(om, om, 0);
|
||||||
|
const __m512d om3rr = _mm512_shuffle_pd(om, om, 15);
|
||||||
|
const __m512d om3ii = _mm512_shuffle_pd(om, om, 15);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
__m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||||
|
__m512d bri% = _mm512_loadu_pd((aa+OFFSET)[%]);
|
||||||
|
__m512d cri% = _mm512_loadu_pd((aa+2*OFFSET)[%]);
|
||||||
|
__m512d dri% = _mm512_loadu_pd((aa+3*OFFSET)[%]);
|
||||||
|
__m512d pa% = _mm512_shuffle_pd(cri%,cri%,5);
|
||||||
|
__m512d pb% = _mm512_shuffle_pd(dri%,dri%,5);
|
||||||
|
pa% = _mm512_mul_pd(pa%,om1ii);
|
||||||
|
pb% = _mm512_mul_pd(pb%,om1ii);
|
||||||
|
pa% = _mm512_fmaddsub_pd(cri%,om1rr,pa%);
|
||||||
|
pb% = _mm512_fmaddsub_pd(dri%,om1rr,pb%);
|
||||||
|
cri% = _mm512_sub_pd(ari%,pa%);
|
||||||
|
dri% = _mm512_sub_pd(bri%,pb%);
|
||||||
|
ari% = _mm512_add_pd(ari%,pa%);
|
||||||
|
bri% = _mm512_add_pd(bri%,pb%);
|
||||||
|
pa% = _mm512_shuffle_pd(bri%,bri%,5);
|
||||||
|
pb% = _mm512_shuffle_pd(dri%,dri%,5);
|
||||||
|
pa% = _mm512_mul_pd(pa%,om2ii);
|
||||||
|
pb% = _mm512_mul_pd(pb%,om3ii);
|
||||||
|
pa% = _mm512_fmaddsub_pd(bri%,om2rr,pa%);
|
||||||
|
pb% = _mm512_fmaddsub_pd(dri%,om3rr,pb%);
|
||||||
|
bri% = _mm512_sub_pd(ari%,pa%);
|
||||||
|
dri% = _mm512_sub_pd(cri%,pb%);
|
||||||
|
ari% = _mm512_add_pd(ari%,pa%);
|
||||||
|
cri% = _mm512_add_pd(cri%,pb%);
|
||||||
|
_mm512_storeu_pd(aa[%], ari%);
|
||||||
|
_mm512_storeu_pd((aa+OFFSET)[%],bri%);
|
||||||
|
_mm512_storeu_pd((aa+2*OFFSET)[%],cri%);
|
||||||
|
_mm512_storeu_pd((aa+3*OFFSET)[%],dri%);
|
||||||
|
aa += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 2
|
||||||
|
__m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||||
|
__m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||||
|
__m512d bri0 = _mm512_loadu_pd((aa+OFFSET)[0]);
|
||||||
|
__m512d bri1 = _mm512_loadu_pd((aa+OFFSET)[1]);
|
||||||
|
__m512d cri0 = _mm512_loadu_pd((aa+2*OFFSET)[0]);
|
||||||
|
__m512d cri1 = _mm512_loadu_pd((aa+2*OFFSET)[1]);
|
||||||
|
__m512d dri0 = _mm512_loadu_pd((aa+3*OFFSET)[0]);
|
||||||
|
__m512d dri1 = _mm512_loadu_pd((aa+3*OFFSET)[1]);
|
||||||
|
__m512d pa0 = _mm512_shuffle_pd(cri0,cri0,5);
|
||||||
|
__m512d pa1 = _mm512_shuffle_pd(cri1,cri1,5);
|
||||||
|
__m512d pb0 = _mm512_shuffle_pd(dri0,dri0,5);
|
||||||
|
__m512d pb1 = _mm512_shuffle_pd(dri1,dri1,5);
|
||||||
|
pa0 = _mm512_mul_pd(pa0,om1ii);
|
||||||
|
pa1 = _mm512_mul_pd(pa1,om1ii);
|
||||||
|
pb0 = _mm512_mul_pd(pb0,om1ii);
|
||||||
|
pb1 = _mm512_mul_pd(pb1,om1ii);
|
||||||
|
pa0 = _mm512_fmaddsub_pd(cri0,om1rr,pa0);
|
||||||
|
pa1 = _mm512_fmaddsub_pd(cri1,om1rr,pa1);
|
||||||
|
pb0 = _mm512_fmaddsub_pd(dri0,om1rr,pb0);
|
||||||
|
pb1 = _mm512_fmaddsub_pd(dri1,om1rr,pb1);
|
||||||
|
cri0 = _mm512_sub_pd(ari0,pa0);
|
||||||
|
cri1 = _mm512_sub_pd(ari1,pa1);
|
||||||
|
dri0 = _mm512_sub_pd(bri0,pb0);
|
||||||
|
dri1 = _mm512_sub_pd(bri1,pb1);
|
||||||
|
ari0 = _mm512_add_pd(ari0,pa0);
|
||||||
|
ari1 = _mm512_add_pd(ari1,pa1);
|
||||||
|
bri0 = _mm512_add_pd(bri0,pb0);
|
||||||
|
bri1 = _mm512_add_pd(bri1,pb1);
|
||||||
|
pa0 = _mm512_shuffle_pd(bri0,bri0,5);
|
||||||
|
pa1 = _mm512_shuffle_pd(bri1,bri1,5);
|
||||||
|
pb0 = _mm512_shuffle_pd(dri0,dri0,5);
|
||||||
|
pb1 = _mm512_shuffle_pd(dri1,dri1,5);
|
||||||
|
pa0 = _mm512_mul_pd(pa0,om2ii);
|
||||||
|
pa1 = _mm512_mul_pd(pa1,om2ii);
|
||||||
|
pb0 = _mm512_mul_pd(pb0,om3ii);
|
||||||
|
pb1 = _mm512_mul_pd(pb1,om3ii);
|
||||||
|
pa0 = _mm512_fmaddsub_pd(bri0,om2rr,pa0);
|
||||||
|
pa1 = _mm512_fmaddsub_pd(bri1,om2rr,pa1);
|
||||||
|
pb0 = _mm512_fmaddsub_pd(dri0,om3rr,pb0);
|
||||||
|
pb1 = _mm512_fmaddsub_pd(dri1,om3rr,pb1);
|
||||||
|
bri0 = _mm512_sub_pd(ari0,pa0);
|
||||||
|
bri1 = _mm512_sub_pd(ari1,pa1);
|
||||||
|
dri0 = _mm512_sub_pd(cri0,pb0);
|
||||||
|
dri1 = _mm512_sub_pd(cri1,pb1);
|
||||||
|
ari0 = _mm512_add_pd(ari0,pa0);
|
||||||
|
ari1 = _mm512_add_pd(ari1,pa1);
|
||||||
|
cri0 = _mm512_add_pd(cri0,pb0);
|
||||||
|
cri1 = _mm512_add_pd(cri1,pb1);
|
||||||
|
_mm512_storeu_pd(aa[0], ari0);
|
||||||
|
_mm512_storeu_pd(aa[1], ari1);
|
||||||
|
_mm512_storeu_pd((aa+OFFSET)[0],bri0);
|
||||||
|
_mm512_storeu_pd((aa+OFFSET)[1],bri1);
|
||||||
|
_mm512_storeu_pd((aa+2*OFFSET)[0],cri0);
|
||||||
|
_mm512_storeu_pd((aa+2*OFFSET)[1],cri1);
|
||||||
|
_mm512_storeu_pd((aa+3*OFFSET)[0],dri0);
|
||||||
|
_mm512_storeu_pd((aa+3*OFFSET)[1],dri1);
|
||||||
|
aa += 2; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
_mm256_zeroupper();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
123
spqlios/lib/spqlios/cplx/cplx_fft_internal.h
Normal file
123
spqlios/lib/spqlios/cplx/cplx_fft_internal.h
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
#ifndef SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||||
|
#define SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||||
|
|
||||||
|
#include "cplx_fft.h"
|
||||||
|
|
||||||
|
/** @brief a complex number contains two doubles real,imag */
|
||||||
|
typedef double CPLX[2];
|
||||||
|
|
||||||
|
EXPORT void cplx_set(CPLX r, const CPLX a);
|
||||||
|
EXPORT void cplx_neg(CPLX r, const CPLX a);
|
||||||
|
EXPORT void cplx_add(CPLX r, const CPLX a, const CPLX b);
|
||||||
|
EXPORT void cplx_sub(CPLX r, const CPLX a, const CPLX b);
|
||||||
|
EXPORT void cplx_mul(CPLX r, const CPLX a, const CPLX b);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
|
||||||
|
* Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||||
|
* Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||||
|
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||||
|
* @param h number of "coefficients" h >= 1
|
||||||
|
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||||
|
* @param powom y represented as (yre,yim)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom);
|
||||||
|
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input: Q(y),Q(-y)
|
||||||
|
* Output: P_0(z),P_1(z)
|
||||||
|
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||||
|
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||||
|
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||||
|
*/
|
||||||
|
EXPORT void split_fft_last_ref(CPLX* data, const CPLX powom);
|
||||||
|
|
||||||
|
EXPORT void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
|
||||||
|
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega);
|
||||||
|
EXPORT void cplx_ifft16_ref(void* data, const void* omega);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief compute the ifft evaluations of P in place
|
||||||
|
* ifft(data) = ifft_rec(data, i);
|
||||||
|
* function ifft_rec(data, omega) {
|
||||||
|
* if #data = 1: return data
|
||||||
|
* let s = sqrt(omega) w. re(s)>0
|
||||||
|
* let (u,v) = data
|
||||||
|
* return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s)
|
||||||
|
* }
|
||||||
|
* @param itables precomputed tables (contains all the powers of omega in the order they are used)
|
||||||
|
* @param data vector of m complexes (coeffs as input, evals as output)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* itables, void* data);
|
||||||
|
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data);
|
||||||
|
EXPORT void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
|
||||||
|
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega);
|
||||||
|
EXPORT void cplx_fft16_ref(void* data, const void* omega);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief compute the fft evaluations of P in place
|
||||||
|
* fft(data) = fft_rec(data, i);
|
||||||
|
* function fft_rec(data, omega) {
|
||||||
|
* if #data = 1: return data
|
||||||
|
* let s = sqrt(omega) w. re(s)>0
|
||||||
|
* let (u,v) = merge_fft(data, s)
|
||||||
|
* return [fft_rec(u, s), fft_rec(v, -s)]
|
||||||
|
* }
|
||||||
|
* @param tables precomputed tables (contains all the powers of omega in the order they are used)
|
||||||
|
* @param data vector of m complexes (coeffs as input, evals as output)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_fft_ref(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||||
|
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial
|
||||||
|
* Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||||
|
* Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||||
|
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||||
|
* @param h number of "coefficients" h >= 1
|
||||||
|
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||||
|
* @param powom y represented as (yre,yim)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom);
|
||||||
|
|
||||||
|
EXPORT void citwiddle(CPLX a, CPLX b, const CPLX om);
|
||||||
|
EXPORT void ctwiddle(CPLX a, CPLX b, const CPLX om);
|
||||||
|
EXPORT void invctwiddle(CPLX a, CPLX b, const CPLX ombar);
|
||||||
|
EXPORT void invcitwiddle(CPLX a, CPLX b, const CPLX ombar);
|
||||||
|
|
||||||
|
// CONVERSIONS
|
||||||
|
|
||||||
|
/** @brief r = x from ZnX (coeffs as signed int32_t's ) to double */
|
||||||
|
EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||||
|
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||||
|
/** @brief r = x to ZnX (coeffs as signed int32_t's ) to double */
|
||||||
|
EXPORT void cplx_to_znx32_ref(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||||
|
EXPORT void cplx_to_znx32_avx2_fma(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||||
|
/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) to double */
|
||||||
|
EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||||
|
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||||
|
/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) */
|
||||||
|
EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
|
||||||
|
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
|
||||||
|
/** @brief r = x from RnX (coeffs as doubles ) to double */
|
||||||
|
EXPORT void cplx_from_rnx64_ref(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||||
|
EXPORT void cplx_from_rnx64_avx2_fma(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||||
|
/** @brief r = x to RnX (coeffs as doubles ) to double */
|
||||||
|
EXPORT void cplx_to_rnx64_ref(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
EXPORT void cplx_to_rnx64_avx2_fma(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
/** @brief r = x to integers in RnX (coeffs as doubles ) to double */
|
||||||
|
EXPORT void cplx_round_to_rnx64_ref(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
EXPORT void cplx_round_to_rnx64_avx2_fma(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
|
||||||
|
// fftvec operations
|
||||||
|
/** @brief element-wise addmul r += ab */
|
||||||
|
EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||||
|
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||||
|
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||||
|
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||||
|
/** @brief element-wise mul r = ab */
|
||||||
|
EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||||
|
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||||
109
spqlios/lib/spqlios/cplx/cplx_fft_private.h
Normal file
109
spqlios/lib/spqlios/cplx/cplx_fft_private.h
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
#ifndef SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||||
|
#define SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||||
|
|
||||||
|
#include "cplx_fft.h"
|
||||||
|
|
||||||
|
typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP;
|
||||||
|
typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP;
|
||||||
|
|
||||||
|
typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*);
|
||||||
|
typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*);
|
||||||
|
// conversions
|
||||||
|
typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*);
|
||||||
|
typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*);
|
||||||
|
typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*);
|
||||||
|
typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*);
|
||||||
|
typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||||
|
typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||||
|
// fftvec operations
|
||||||
|
typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*);
|
||||||
|
typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*);
|
||||||
|
|
||||||
|
typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*);
|
||||||
|
typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*);
|
||||||
|
|
||||||
|
struct cplx_ifft_precomp {
|
||||||
|
IFFT_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
uint64_t buf_size;
|
||||||
|
double* powomegas;
|
||||||
|
void* aligned_buffers;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_fft_precomp {
|
||||||
|
FFT_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
uint64_t buf_size;
|
||||||
|
double* powomegas;
|
||||||
|
void* aligned_buffers;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_from_znx32_precomp {
|
||||||
|
FROM_ZNX32_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_to_znx32_precomp {
|
||||||
|
TO_ZNX32_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
double divisor;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_from_tnx32_precomp {
|
||||||
|
FROM_TNX32_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_to_tnx32_precomp {
|
||||||
|
TO_TNX32_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
double divisor;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_from_rnx64_precomp {
|
||||||
|
FROM_RNX64_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_to_rnx64_precomp {
|
||||||
|
TO_RNX64_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
double divisor;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_round_to_rnx64_precomp {
|
||||||
|
ROUND_TO_RNX64_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
double divisor;
|
||||||
|
uint32_t log2bound;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct cplx_mul_precomp {
|
||||||
|
FFTVEC_MUL_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
} CPLX_FFTVEC_MUL_PRECOMP;
|
||||||
|
|
||||||
|
typedef struct cplx_addmul_precomp {
|
||||||
|
FFTVEC_ADDMUL_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
} CPLX_FFTVEC_ADDMUL_PRECOMP;
|
||||||
|
|
||||||
|
struct cplx_twiddle_precomp {
|
||||||
|
FFTVEC_TWIDDLE_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct cplx_bitwiddle_precomp {
|
||||||
|
FFTVEC_BITWIDDLE_FUNCTION function;
|
||||||
|
int64_t m;
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
|
||||||
|
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
|
||||||
|
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||||
|
const void* om);
|
||||||
|
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||||
|
const void* om);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||||
367
spqlios/lib/spqlios/cplx/cplx_fft_ref.c
Normal file
367
spqlios/lib/spqlios/cplx/cplx_fft_ref.c
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
/** @brief (a,b) <- (a+omega.b,a-omega.b) */
|
||||||
|
void ctwiddle(CPLX a, CPLX b, const CPLX om) {
|
||||||
|
double re = om[0] * b[0] - om[1] * b[1];
|
||||||
|
double im = om[0] * b[1] + om[1] * b[0];
|
||||||
|
b[0] = a[0] - re;
|
||||||
|
b[1] = a[1] - im;
|
||||||
|
a[0] += re;
|
||||||
|
a[1] += im;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief (a,b) <- (a+i.omega.b,a-i.omega.b) */
|
||||||
|
void citwiddle(CPLX a, CPLX b, const CPLX om) {
|
||||||
|
double re = -om[1] * b[0] - om[0] * b[1];
|
||||||
|
double im = -om[1] * b[1] + om[0] * b[0];
|
||||||
|
b[0] = a[0] - re;
|
||||||
|
b[1] = a[1] - im;
|
||||||
|
a[0] += re;
|
||||||
|
a[1] += im;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief FFT modulo X^16-omega^2 (in registers)
|
||||||
|
* @param data contains 16 complexes
|
||||||
|
* @param omega 8 complexes in this order:
|
||||||
|
* omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||||
|
* alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
* j = sqrt(i), k=sqrt(j)
|
||||||
|
*/
|
||||||
|
void cplx_fft16_ref(void* data, const void* omega) {
|
||||||
|
CPLX* d = data;
|
||||||
|
const CPLX* om = omega;
|
||||||
|
// first pass
|
||||||
|
for (uint64_t i = 0; i < 8; ++i) {
|
||||||
|
ctwiddle(d[0 + i], d[8 + i], om[0]);
|
||||||
|
}
|
||||||
|
//
|
||||||
|
ctwiddle(d[0], d[4], om[1]);
|
||||||
|
ctwiddle(d[1], d[5], om[1]);
|
||||||
|
ctwiddle(d[2], d[6], om[1]);
|
||||||
|
ctwiddle(d[3], d[7], om[1]);
|
||||||
|
citwiddle(d[8], d[12], om[1]);
|
||||||
|
citwiddle(d[9], d[13], om[1]);
|
||||||
|
citwiddle(d[10], d[14], om[1]);
|
||||||
|
citwiddle(d[11], d[15], om[1]);
|
||||||
|
//
|
||||||
|
ctwiddle(d[0], d[2], om[2]);
|
||||||
|
ctwiddle(d[1], d[3], om[2]);
|
||||||
|
citwiddle(d[4], d[6], om[2]);
|
||||||
|
citwiddle(d[5], d[7], om[2]);
|
||||||
|
ctwiddle(d[8], d[10], om[3]);
|
||||||
|
ctwiddle(d[9], d[11], om[3]);
|
||||||
|
citwiddle(d[12], d[14], om[3]);
|
||||||
|
citwiddle(d[13], d[15], om[3]);
|
||||||
|
//
|
||||||
|
ctwiddle(d[0], d[1], om[4]);
|
||||||
|
citwiddle(d[2], d[3], om[4]);
|
||||||
|
ctwiddle(d[4], d[5], om[5]);
|
||||||
|
citwiddle(d[6], d[7], om[5]);
|
||||||
|
ctwiddle(d[8], d[9], om[6]);
|
||||||
|
citwiddle(d[10], d[11], om[6]);
|
||||||
|
ctwiddle(d[12], d[13], om[7]);
|
||||||
|
citwiddle(d[14], d[15], om[7]);
|
||||||
|
}
|
||||||
|
|
||||||
|
double cos_2pix(double x) { return m_accurate_cos(2 * M_PI * x); }
|
||||||
|
double sin_2pix(double x) { return m_accurate_sin(2 * M_PI * x); }
|
||||||
|
void cplx_set_e2pix(CPLX res, double x) {
|
||||||
|
res[0] = cos_2pix(x);
|
||||||
|
res[1] = sin_2pix(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_fft16_precomp(const double entry_pwr, CPLX** omg) {
|
||||||
|
static const double j_pow = 1. / 8.;
|
||||||
|
static const double k_pow = 1. / 16.;
|
||||||
|
const double pom = entry_pwr / 2.;
|
||||||
|
const double pom_2 = entry_pwr / 4.;
|
||||||
|
const double pom_4 = entry_pwr / 8.;
|
||||||
|
const double pom_8 = entry_pwr / 16.;
|
||||||
|
cplx_set_e2pix((*omg)[0], pom);
|
||||||
|
cplx_set_e2pix((*omg)[1], pom_2);
|
||||||
|
cplx_set_e2pix((*omg)[2], pom_4);
|
||||||
|
cplx_set_e2pix((*omg)[3], pom_4 + j_pow);
|
||||||
|
cplx_set_e2pix((*omg)[4], pom_8);
|
||||||
|
cplx_set_e2pix((*omg)[5], pom_8 + j_pow);
|
||||||
|
cplx_set_e2pix((*omg)[6], pom_8 + k_pow);
|
||||||
|
cplx_set_e2pix((*omg)[7], pom_8 + j_pow + k_pow);
|
||||||
|
*omg += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief h twiddles-fft on the same omega
|
||||||
|
* (also called merge-fft)merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial
|
||||||
|
* Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||||
|
* Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||||
|
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||||
|
* @param h number of "coefficients" h >= 1
|
||||||
|
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||||
|
* @param powom y represented as (yre,yim)
|
||||||
|
*/
|
||||||
|
void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
|
||||||
|
CPLX* d0 = data;
|
||||||
|
CPLX* d1 = data + h;
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
ctwiddle(d0[i], d1[i], powom);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_bitwiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
|
||||||
|
CPLX* d0 = data;
|
||||||
|
CPLX* d1 = data + h;
|
||||||
|
CPLX* d2 = data + 2*h;
|
||||||
|
CPLX* d3 = data + 3*h;
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
ctwiddle(d0[i], d2[i], powom[0]);
|
||||||
|
ctwiddle(d1[i], d3[i], powom[0]);
|
||||||
|
}
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
ctwiddle(d0[i], d1[i], powom[1]);
|
||||||
|
citwiddle(d2[i], d3[i], powom[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input: P_0(z),P_1(z)
|
||||||
|
* Output: Q(y),Q(-y)
|
||||||
|
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||||
|
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||||
|
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||||
|
*/
|
||||||
|
void merge_fft_last_ref(CPLX* data, const CPLX powom) {
|
||||||
|
CPLX prod;
|
||||||
|
cplx_mul(prod, data[1], powom);
|
||||||
|
cplx_sub(data[1], data[0], prod);
|
||||||
|
cplx_add(data[0], data[0], prod);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_fft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
CPLX* data = (CPLX*)dat;
|
||||||
|
CPLX* const dend = data + m;
|
||||||
|
for (int32_t h = m / 2; h >= 2; h >>= 1) {
|
||||||
|
for (CPLX* d = data; d < dend; d += 2 * h) {
|
||||||
|
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||||
|
cplx_twiddle_fft_ref(h, d, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after merge %d: ", h);
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
for (CPLX* d = data; d < dend; d += 2) {
|
||||||
|
// TODO see if encoding changes
|
||||||
|
if ((*omg)[0][0] != -(*omg)[1][0]) abort();
|
||||||
|
if ((*omg)[0][1] != -(*omg)[1][1]) abort();
|
||||||
|
merge_fft_last_ref(d, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after last: ");
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_fft_ref_bfs_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
CPLX* data = (CPLX*)dat;
|
||||||
|
CPLX* const dend = data + m;
|
||||||
|
uint32_t mm = m;
|
||||||
|
uint32_t log2m = log2(m);
|
||||||
|
if (log2m % 2 == 1) {
|
||||||
|
cplx_twiddle_fft_ref(mm/2, data, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
mm >>= 1;
|
||||||
|
}
|
||||||
|
while(mm>16) {
|
||||||
|
uint32_t h = mm/4;
|
||||||
|
for (CPLX* d = data; d < dend; d += mm) {
|
||||||
|
cplx_bitwiddle_fft_ref(h, d, *omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
mm=h;
|
||||||
|
}
|
||||||
|
for (CPLX* d = data; d < dend; d += 16) {
|
||||||
|
cplx_fft16_ref(d, *omg);
|
||||||
|
*omg += 8;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after last: ");
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */
|
||||||
|
void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data) {
|
||||||
|
if (m == 1) return;
|
||||||
|
const double pom = entry_pwr / 2.;
|
||||||
|
const uint32_t h = m / 2;
|
||||||
|
// apply the twiddle factors
|
||||||
|
CPLX cpom;
|
||||||
|
cplx_set_e2pix(cpom, pom);
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
ctwiddle(data[i], data[i + h], cpom);
|
||||||
|
}
|
||||||
|
// do the recursive calls
|
||||||
|
cplx_fft_naive(h, pom, data);
|
||||||
|
cplx_fft_naive(h, pom + 0.5, data + h);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief fills omega for cplx_fft_bfs_16 modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||||
|
void fill_cplx_fft_omegas_bfs_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
uint32_t mm = m;
|
||||||
|
uint32_t log2m = log2(m);
|
||||||
|
double ss = entry_pwr;
|
||||||
|
if (log2m % 2 == 1) {
|
||||||
|
uint32_t h = mm / 2;
|
||||||
|
double pom = ss / 2.;
|
||||||
|
for (uint32_t i = 0; i < m / mm; i++) {
|
||||||
|
cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||||
|
cplx_set(omg[0][1], omg[0][0]);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
mm = h;
|
||||||
|
ss = pom;
|
||||||
|
}
|
||||||
|
while (mm>16) {
|
||||||
|
double pom = ss / 4.;
|
||||||
|
uint32_t h = mm / 4;
|
||||||
|
for (uint32_t i = 0; i < m / mm; i++) {
|
||||||
|
double om = pom + fracrevbits(i) / 4.;
|
||||||
|
cplx_set_e2pix(omg[0][0], 2. * om);
|
||||||
|
cplx_set_e2pix(omg[0][1], om);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
mm = h;
|
||||||
|
ss = pom;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// mm=16
|
||||||
|
for (uint32_t i = 0; i < m / 16; i++) {
|
||||||
|
cplx_fft16_precomp(ss + fracrevbits(i), omg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief fills omega for cplx_fft_bfs_2 modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||||
|
void fill_cplx_fft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
double pom = entry_pwr / 2.;
|
||||||
|
for (int32_t h = m / 2; h >= 2; h >>= 1) {
|
||||||
|
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||||
|
cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||||
|
cplx_set(omg[0][1], omg[0][0]);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
pom /= 2;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// h=1
|
||||||
|
for (uint32_t i = 0; i < m / 2; i++) {
|
||||||
|
cplx_set_e2pix((*omg)[0], pom + fracrevbits(i) / 2.);
|
||||||
|
cplx_neg((*omg)[1], (*omg)[0]);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief fills omega for cplx_fft_rec modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||||
|
void fill_cplx_fft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
// note that the cases below are for recursive calls only!
|
||||||
|
// externally, this function shall only be called with m>=4096
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return fill_cplx_fft_omegas_bfs_2(entry_pwr, omg, m);
|
||||||
|
if (m <= 2048) return fill_cplx_fft_omegas_bfs_16(entry_pwr, omg, m);
|
||||||
|
double pom = entry_pwr / 2.;
|
||||||
|
cplx_set_e2pix((*omg)[0], pom);
|
||||||
|
cplx_set_e2pix((*omg)[1], pom);
|
||||||
|
*omg += 2;
|
||||||
|
fill_cplx_fft_omegas_rec_16(pom, omg, m / 2);
|
||||||
|
fill_cplx_fft_omegas_rec_16(pom + 0.5, omg, m / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_fft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return cplx_fft_ref_bfs_2(dat, omg, m);
|
||||||
|
if (m <= 2048) return cplx_fft_ref_bfs_16(dat, omg, m);
|
||||||
|
const uint32_t h = m / 2;
|
||||||
|
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||||
|
cplx_twiddle_fft_ref(h, dat, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
cplx_fft_ref_rec_16(dat, omg, h);
|
||||||
|
cplx_fft_ref_rec_16(dat + h, omg, h);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_fft_ref(const CPLX_FFT_PRECOMP* precomp, void* d) {
|
||||||
|
CPLX* data = (CPLX*)d;
|
||||||
|
const int32_t m = precomp->m;
|
||||||
|
const CPLX* omg = (CPLX*)precomp->powomegas;
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return cplx_fft_ref_bfs_2(data, &omg, m);
|
||||||
|
if (m <= 2048) return cplx_fft_ref_bfs_16(data, &omg, m);
|
||||||
|
cplx_fft_ref_rec_16(data, &omg, m);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers) {
|
||||||
|
const uint64_t OMG_SPACE = ceilto64b((2 * m)* sizeof(CPLX));
|
||||||
|
const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX));
|
||||||
|
void* reps = malloc(sizeof(CPLX_FFT_PRECOMP) + 63 // padding
|
||||||
|
+ OMG_SPACE // tables //TODO 16?
|
||||||
|
+ num_buffers * BUF_SIZE // buffers
|
||||||
|
);
|
||||||
|
uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(CPLX_FFT_PRECOMP));
|
||||||
|
CPLX_FFT_PRECOMP* r = (CPLX_FFT_PRECOMP*)reps;
|
||||||
|
r->m = m;
|
||||||
|
r->buf_size = BUF_SIZE;
|
||||||
|
r->powomegas = (double*)aligned_addr;
|
||||||
|
r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE);
|
||||||
|
// fill in powomegas
|
||||||
|
CPLX* omg = (CPLX*)r->powomegas;
|
||||||
|
if (m <= 8) {
|
||||||
|
fill_cplx_fft_omegas_bfs_2(0.25, &omg, m);
|
||||||
|
} else if (m <= 2048) {
|
||||||
|
fill_cplx_fft_omegas_bfs_16(0.25, &omg, m);
|
||||||
|
} else {
|
||||||
|
fill_cplx_fft_omegas_rec_16(0.25, &omg, m);
|
||||||
|
}
|
||||||
|
if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort();
|
||||||
|
// dispatch the right implementation
|
||||||
|
{
|
||||||
|
if (m <= 4) {
|
||||||
|
// currently, we do not have any acceletated
|
||||||
|
// implementation for m<=4
|
||||||
|
r->function = cplx_fft_ref;
|
||||||
|
} else if (CPU_SUPPORTS("fma")) {
|
||||||
|
r->function = cplx_fft_avx2_fma;
|
||||||
|
} else {
|
||||||
|
r->function = cplx_fft_ref;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reps;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index) {
|
||||||
|
return (uint8_t *)tables->aligned_buffers + buffer_index * tables->buf_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fft_simple(uint32_t m, void* data) {
|
||||||
|
static CPLX_FFT_PRECOMP* p[31] = {0};
|
||||||
|
CPLX_FFT_PRECOMP** f = p + log2m(m);
|
||||||
|
if (!*f) *f = new_cplx_fft_precomp(m, 0);
|
||||||
|
(*f)->function(*f, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
|
||||||
310
spqlios/lib/spqlios/cplx/cplx_fft_sse.c
Normal file
310
spqlios/lib/spqlios/cplx/cplx_fft_sse.c
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef double D2MEM[2];
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const D2MEM* aa = (D2MEM*)a;
|
||||||
|
const D2MEM* bb = (D2MEM*)b;
|
||||||
|
D2MEM* rr = (D2MEM*)r;
|
||||||
|
const D2MEM* const aend = aa + m;
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m128d ari% = _mm_loadu_pd(aa[%]);
|
||||||
|
const __m128d bri% = _mm_loadu_pd(bb[%]);
|
||||||
|
const __m128d rri% = _mm_loadu_pd(rr[%]);
|
||||||
|
const __m128d bir% = _mm_shuffle_pd(bri%,bri%, 5);
|
||||||
|
const __m128d aii% = _mm_shuffle_pd(ari%,ari%, 15);
|
||||||
|
const __m128d pro% = _mm_fmaddsub_pd(aii%,bir%,rri%);
|
||||||
|
const __m128d arr% = _mm_shuffle_pd(ari%,ari%, 0);
|
||||||
|
const __m128d res% = _mm_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm_storeu_pd(rr[%],res%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
bb += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 2
|
||||||
|
const __m128d ari0 = _mm_loadu_pd(aa[0]);
|
||||||
|
const __m128d ari1 = _mm_loadu_pd(aa[1]);
|
||||||
|
const __m128d bri0 = _mm_loadu_pd(bb[0]);
|
||||||
|
const __m128d bri1 = _mm_loadu_pd(bb[1]);
|
||||||
|
const __m128d rri0 = _mm_loadu_pd(rr[0]);
|
||||||
|
const __m128d rri1 = _mm_loadu_pd(rr[1]);
|
||||||
|
const __m128d bir0 = _mm_shuffle_pd(bri0, bri0, 0b01);
|
||||||
|
const __m128d bir1 = _mm_shuffle_pd(bri1, bri1, 0b01);
|
||||||
|
const __m128d aii0 = _mm_shuffle_pd(ari0, ari0, 0b11);
|
||||||
|
const __m128d aii1 = _mm_shuffle_pd(ari1, ari1, 0b11);
|
||||||
|
const __m128d pro0 = _mm_fmaddsub_pd(aii0, bir0, rri0);
|
||||||
|
const __m128d pro1 = _mm_fmaddsub_pd(aii1, bir1, rri1);
|
||||||
|
const __m128d arr0 = _mm_shuffle_pd(ari0, ari0, 0b00);
|
||||||
|
const __m128d arr1 = _mm_shuffle_pd(ari1, ari1, 0b00);
|
||||||
|
const __m128d res0 = _mm_fmaddsub_pd(arr0, bri0, pro0);
|
||||||
|
const __m128d res1 = _mm_fmaddsub_pd(arr1, bri1, pro1);
|
||||||
|
_mm_storeu_pd(rr[0], res0);
|
||||||
|
_mm_storeu_pd(rr[1], res1);
|
||||||
|
rr += 2; // ONCE
|
||||||
|
aa += 2; // ONCE
|
||||||
|
bb += 2; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||||
|
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||||
|
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||||
|
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||||
|
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b
|
||||||
|
const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a
|
||||||
|
const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a
|
||||||
|
const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a
|
||||||
|
const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a
|
||||||
|
const __m256d pro0 = _mm256_mul_pd(aii0,bir0);
|
||||||
|
const __m256d pro1 = _mm256_mul_pd(aii1,bir1);
|
||||||
|
const __m256d pro2 = _mm256_mul_pd(aii2,bir2);
|
||||||
|
const __m256d pro3 = _mm256_mul_pd(aii3,bir3);
|
||||||
|
const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a
|
||||||
|
const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a
|
||||||
|
const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a
|
||||||
|
const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a
|
||||||
|
const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0);
|
||||||
|
const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1);
|
||||||
|
const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2);
|
||||||
|
const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d res0 = _mm256_add_pd(ari0,bri0);
|
||||||
|
const __m256d res1 = _mm256_add_pd(ari1,bri1);
|
||||||
|
const __m256d res2 = _mm256_add_pd(ari2,bri2);
|
||||||
|
const __m256d res3 = _mm256_add_pd(ari3,bri3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||||
|
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||||
|
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d sum0 = _mm256_add_pd(ari0,bri0);
|
||||||
|
const __m256d sum1 = _mm256_add_pd(ari1,bri1);
|
||||||
|
const __m256d sum2 = _mm256_add_pd(ari2,bri2);
|
||||||
|
const __m256d sum3 = _mm256_add_pd(ari3,bri3);
|
||||||
|
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||||
|
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||||
|
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||||
|
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||||
|
const __m256d res0 = _mm256_sub_pd(rri0,sum0);
|
||||||
|
const __m256d res1 = _mm256_sub_pd(rri1,sum1);
|
||||||
|
const __m256d res2 = _mm256_sub_pd(rri2,sum2);
|
||||||
|
const __m256d res3 = _mm256_sub_pd(rri3,sum3);
|
||||||
|
_mm256_storeu_pd(rr[0],res0);
|
||||||
|
_mm256_storeu_pd(rr[1],res1);
|
||||||
|
_mm256_storeu_pd(rr[2],res2);
|
||||||
|
_mm256_storeu_pd(rr[3],res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(rr[%],ari%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(rr[0],ari0);
|
||||||
|
_mm256_storeu_pd(rr[1],ari1);
|
||||||
|
_mm256_storeu_pd(rr[2],ari2);
|
||||||
|
_mm256_storeu_pd(rr[3],ari3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) {
|
||||||
|
double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*bb)[4] = (double(*)[4])b;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
const __m256d om = _mm256_loadu_pd(omg);
|
||||||
|
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||||
|
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||||
|
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||||
|
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5);
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5);
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5);
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5);
|
||||||
|
__m256d p0 = _mm256_mul_pd(bir0,omii);
|
||||||
|
__m256d p1 = _mm256_mul_pd(bir1,omii);
|
||||||
|
__m256d p2 = _mm256_mul_pd(bir2,omii);
|
||||||
|
__m256d p3 = _mm256_mul_pd(bir3,omii);
|
||||||
|
p0 = _mm256_fmaddsub_pd(bri0,omrr,p0);
|
||||||
|
p1 = _mm256_fmaddsub_pd(bri1,omrr,p1);
|
||||||
|
p2 = _mm256_fmaddsub_pd(bri2,omrr,p2);
|
||||||
|
p3 = _mm256_fmaddsub_pd(bri3,omrr,p3);
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0));
|
||||||
|
_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1));
|
||||||
|
_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2));
|
||||||
|
_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3));
|
||||||
|
_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0));
|
||||||
|
_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1));
|
||||||
|
_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2));
|
||||||
|
_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3));
|
||||||
|
// END_INTERLEAVE
|
||||||
|
bb += 4;
|
||||||
|
aa += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_innerprod_avx2_fma(const CPLX_FFTVEC_INNERPROD_PRECOMP* precomp, const int32_t ellbar,
|
||||||
|
const uint64_t lda, const uint64_t ldb,
|
||||||
|
void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const uint32_t blk = precomp->blk;
|
||||||
|
const uint32_t nblocks = precomp->nblocks;
|
||||||
|
const CPLX* aa = (CPLX*)a;
|
||||||
|
const CPLX* bb = (CPLX*)b;
|
||||||
|
CPLX* rr = (CPLX*)r;
|
||||||
|
const uint64_t ldda = lda >> 4; // in CPLX
|
||||||
|
const uint64_t lddb = ldb >> 4;
|
||||||
|
if (m==0) {
|
||||||
|
memset(r, 0, m*sizeof(CPLX));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (uint32_t k=0; k<nblocks; ++k) {
|
||||||
|
const uint64_t offset = k*blk;
|
||||||
|
const CPLX* aaa = aa+offset;
|
||||||
|
const CPLX* bbb = bb+offset;
|
||||||
|
CPLX *rrr = rr+offset;
|
||||||
|
cplx_fftvec_mul_fma(&precomp->mul_func, rrr, aaa, bbb);
|
||||||
|
for (int32_t i=1; i<ellbar; ++i) {
|
||||||
|
cplx_fftvec_addmul_fma(&precomp->addmul_func, rrr, aaa + i * ldda, bbb + i * lddb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
389
spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c
Normal file
389
spqlios/lib/spqlios/cplx/cplx_fftvec_avx2_fma.c
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef double D4MEM[4];
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||||
|
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||||
|
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||||
|
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||||
|
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
bb += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); // conj of b
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); // conj of b
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); // conj of b
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); // conj of b
|
||||||
|
const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); // im of a
|
||||||
|
const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); // im of a
|
||||||
|
const __m256d aii2 = _mm256_shuffle_pd(ari2, ari2, 15); // im of a
|
||||||
|
const __m256d aii3 = _mm256_shuffle_pd(ari3, ari3, 15); // im of a
|
||||||
|
const __m256d pro0 = _mm256_mul_pd(aii0, bir0);
|
||||||
|
const __m256d pro1 = _mm256_mul_pd(aii1, bir1);
|
||||||
|
const __m256d pro2 = _mm256_mul_pd(aii2, bir2);
|
||||||
|
const __m256d pro3 = _mm256_mul_pd(aii3, bir3);
|
||||||
|
const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); // rr of a
|
||||||
|
const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); // rr of a
|
||||||
|
const __m256d arr2 = _mm256_shuffle_pd(ari2, ari2, 0); // rr of a
|
||||||
|
const __m256d arr3 = _mm256_shuffle_pd(ari3, ari3, 0); // rr of a
|
||||||
|
const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0);
|
||||||
|
const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1);
|
||||||
|
const __m256d res2 = _mm256_fmaddsub_pd(arr2, bri2, pro2);
|
||||||
|
const __m256d res3 = _mm256_fmaddsub_pd(arr3, bri3, pro3);
|
||||||
|
_mm256_storeu_pd(rr[0], res0);
|
||||||
|
_mm256_storeu_pd(rr[1], res1);
|
||||||
|
_mm256_storeu_pd(rr[2], res2);
|
||||||
|
_mm256_storeu_pd(rr[3], res3);
|
||||||
|
rr += 4; // ONCE
|
||||||
|
aa += 4; // ONCE
|
||||||
|
bb += 4; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5);
|
||||||
|
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15);
|
||||||
|
const __m256d pro% = _mm256_fmaddsub_pd(aii%,bir%,rri%);
|
||||||
|
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0);
|
||||||
|
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
bb += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 2
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||||
|
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5);
|
||||||
|
const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15);
|
||||||
|
const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15);
|
||||||
|
const __m256d pro0 = _mm256_fmaddsub_pd(aii0, bir0, rri0);
|
||||||
|
const __m256d pro1 = _mm256_fmaddsub_pd(aii1, bir1, rri1);
|
||||||
|
const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0);
|
||||||
|
const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0);
|
||||||
|
const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0);
|
||||||
|
const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1);
|
||||||
|
_mm256_storeu_pd(rr[0], res0);
|
||||||
|
_mm256_storeu_pd(rr[1], res1);
|
||||||
|
rr += 2; // ONCE
|
||||||
|
aa += 2; // ONCE
|
||||||
|
bb += 2; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea,
|
||||||
|
const void* omg) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const uint64_t OFFSET = slicea / sizeof(D4MEM);
|
||||||
|
D4MEM* aa = (D4MEM*)a;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
const __m256d om = _mm256_loadu_pd(omg);
|
||||||
|
const __m256d om1rr = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d om1ii = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
const __m256d om2rr = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d om2ii = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d om3rr = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
const __m256d om3ii = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
__m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
__m256d bri% = _mm256_loadu_pd((aa+OFFSET)[%]);
|
||||||
|
__m256d cri% = _mm256_loadu_pd((aa+2*OFFSET)[%]);
|
||||||
|
__m256d dri% = _mm256_loadu_pd((aa+3*OFFSET)[%]);
|
||||||
|
__m256d pa% = _mm256_shuffle_pd(cri%,cri%,5);
|
||||||
|
__m256d pb% = _mm256_shuffle_pd(dri%,dri%,5);
|
||||||
|
pa% = _mm256_mul_pd(pa%,om1ii);
|
||||||
|
pb% = _mm256_mul_pd(pb%,om1ii);
|
||||||
|
pa% = _mm256_fmaddsub_pd(cri%,om1rr,pa%);
|
||||||
|
pb% = _mm256_fmaddsub_pd(dri%,om1rr,pb%);
|
||||||
|
cri% = _mm256_sub_pd(ari%,pa%);
|
||||||
|
dri% = _mm256_sub_pd(bri%,pb%);
|
||||||
|
ari% = _mm256_add_pd(ari%,pa%);
|
||||||
|
bri% = _mm256_add_pd(bri%,pb%);
|
||||||
|
pa% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||||
|
pb% = _mm256_shuffle_pd(dri%,dri%,5);
|
||||||
|
pa% = _mm256_mul_pd(pa%,om2ii);
|
||||||
|
pb% = _mm256_mul_pd(pb%,om3ii);
|
||||||
|
pa% = _mm256_fmaddsub_pd(bri%,om2rr,pa%);
|
||||||
|
pb% = _mm256_fmaddsub_pd(dri%,om3rr,pb%);
|
||||||
|
bri% = _mm256_sub_pd(ari%,pa%);
|
||||||
|
dri% = _mm256_sub_pd(cri%,pb%);
|
||||||
|
ari% = _mm256_add_pd(ari%,pa%);
|
||||||
|
cri% = _mm256_add_pd(cri%,pb%);
|
||||||
|
_mm256_storeu_pd(aa[%], ari%);
|
||||||
|
_mm256_storeu_pd((aa+OFFSET)[%],bri%);
|
||||||
|
_mm256_storeu_pd((aa+2*OFFSET)[%],cri%);
|
||||||
|
_mm256_storeu_pd((aa+3*OFFSET)[%],dri%);
|
||||||
|
aa += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 1
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
__m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
__m256d bri0 = _mm256_loadu_pd((aa + OFFSET)[0]);
|
||||||
|
__m256d cri0 = _mm256_loadu_pd((aa + 2 * OFFSET)[0]);
|
||||||
|
__m256d dri0 = _mm256_loadu_pd((aa + 3 * OFFSET)[0]);
|
||||||
|
__m256d pa0 = _mm256_shuffle_pd(cri0, cri0, 5);
|
||||||
|
__m256d pb0 = _mm256_shuffle_pd(dri0, dri0, 5);
|
||||||
|
pa0 = _mm256_mul_pd(pa0, om1ii);
|
||||||
|
pb0 = _mm256_mul_pd(pb0, om1ii);
|
||||||
|
pa0 = _mm256_fmaddsub_pd(cri0, om1rr, pa0);
|
||||||
|
pb0 = _mm256_fmaddsub_pd(dri0, om1rr, pb0);
|
||||||
|
cri0 = _mm256_sub_pd(ari0, pa0);
|
||||||
|
dri0 = _mm256_sub_pd(bri0, pb0);
|
||||||
|
ari0 = _mm256_add_pd(ari0, pa0);
|
||||||
|
bri0 = _mm256_add_pd(bri0, pb0);
|
||||||
|
pa0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||||
|
pb0 = _mm256_shuffle_pd(dri0, dri0, 5);
|
||||||
|
pa0 = _mm256_mul_pd(pa0, om2ii);
|
||||||
|
pb0 = _mm256_mul_pd(pb0, om3ii);
|
||||||
|
pa0 = _mm256_fmaddsub_pd(bri0, om2rr, pa0);
|
||||||
|
pb0 = _mm256_fmaddsub_pd(dri0, om3rr, pb0);
|
||||||
|
bri0 = _mm256_sub_pd(ari0, pa0);
|
||||||
|
dri0 = _mm256_sub_pd(cri0, pb0);
|
||||||
|
ari0 = _mm256_add_pd(ari0, pa0);
|
||||||
|
cri0 = _mm256_add_pd(cri0, pb0);
|
||||||
|
_mm256_storeu_pd(aa[0], ari0);
|
||||||
|
_mm256_storeu_pd((aa + OFFSET)[0], bri0);
|
||||||
|
_mm256_storeu_pd((aa + 2 * OFFSET)[0], cri0);
|
||||||
|
_mm256_storeu_pd((aa + 3 * OFFSET)[0], dri0);
|
||||||
|
aa += 1; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||||
|
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||||
|
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d sum0 = _mm256_add_pd(ari0, bri0);
|
||||||
|
const __m256d sum1 = _mm256_add_pd(ari1, bri1);
|
||||||
|
const __m256d sum2 = _mm256_add_pd(ari2, bri2);
|
||||||
|
const __m256d sum3 = _mm256_add_pd(ari3, bri3);
|
||||||
|
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||||
|
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||||
|
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||||
|
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||||
|
const __m256d res0 = _mm256_sub_pd(rri0, sum0);
|
||||||
|
const __m256d res1 = _mm256_sub_pd(rri1, sum1);
|
||||||
|
const __m256d res2 = _mm256_sub_pd(rri2, sum2);
|
||||||
|
const __m256d res3 = _mm256_sub_pd(rri3, sum3);
|
||||||
|
_mm256_storeu_pd(rr[0], res0);
|
||||||
|
_mm256_storeu_pd(rr[1], res1);
|
||||||
|
_mm256_storeu_pd(rr[2], res2);
|
||||||
|
_mm256_storeu_pd(rr[3], res3);
|
||||||
|
// END_INTERLEAVE
|
||||||
|
rr += 4;
|
||||||
|
aa += 4;
|
||||||
|
bb += 4;
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
const double(*bb)[4] = (double(*)[4])b;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||||
|
_mm256_storeu_pd(rr[%],res%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
bb += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d res0 = _mm256_add_pd(ari0, bri0);
|
||||||
|
const __m256d res1 = _mm256_add_pd(ari1, bri1);
|
||||||
|
const __m256d res2 = _mm256_add_pd(ari2, bri2);
|
||||||
|
const __m256d res3 = _mm256_add_pd(ari3, bri3);
|
||||||
|
_mm256_storeu_pd(rr[0], res0);
|
||||||
|
_mm256_storeu_pd(rr[1], res1);
|
||||||
|
_mm256_storeu_pd(rr[2], res2);
|
||||||
|
_mm256_storeu_pd(rr[3], res3);
|
||||||
|
rr += 4; // ONCE
|
||||||
|
aa += 4; // ONCE
|
||||||
|
bb += 4; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*bb)[4] = (double(*)[4])b;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
const __m256d om = _mm256_loadu_pd(omg);
|
||||||
|
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||||
|
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||||
|
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||||
|
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||||
|
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||||
|
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||||
|
bb += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||||
|
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||||
|
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||||
|
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||||
|
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||||
|
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5);
|
||||||
|
const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5);
|
||||||
|
const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5);
|
||||||
|
__m256d p0 = _mm256_mul_pd(bir0, omii);
|
||||||
|
__m256d p1 = _mm256_mul_pd(bir1, omii);
|
||||||
|
__m256d p2 = _mm256_mul_pd(bir2, omii);
|
||||||
|
__m256d p3 = _mm256_mul_pd(bir3, omii);
|
||||||
|
p0 = _mm256_fmaddsub_pd(bri0, omrr, p0);
|
||||||
|
p1 = _mm256_fmaddsub_pd(bri1, omrr, p1);
|
||||||
|
p2 = _mm256_fmaddsub_pd(bri2, omrr, p2);
|
||||||
|
p3 = _mm256_fmaddsub_pd(bri3, omrr, p3);
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(aa[0], _mm256_add_pd(ari0, p0));
|
||||||
|
_mm256_storeu_pd(aa[1], _mm256_add_pd(ari1, p1));
|
||||||
|
_mm256_storeu_pd(aa[2], _mm256_add_pd(ari2, p2));
|
||||||
|
_mm256_storeu_pd(aa[3], _mm256_add_pd(ari3, p3));
|
||||||
|
_mm256_storeu_pd(bb[0], _mm256_sub_pd(ari0, p0));
|
||||||
|
_mm256_storeu_pd(bb[1], _mm256_sub_pd(ari1, p1));
|
||||||
|
_mm256_storeu_pd(bb[2], _mm256_sub_pd(ari2, p2));
|
||||||
|
_mm256_storeu_pd(bb[3], _mm256_sub_pd(ari3, p3));
|
||||||
|
bb += 4; // ONCE
|
||||||
|
aa += 4; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||||
|
const double(*aa)[4] = (double(*)[4])a;
|
||||||
|
double(*rr)[4] = (double(*)[4])r;
|
||||||
|
const double(*const aend)[4] = aa + (m >> 1);
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||||
|
_mm256_storeu_pd(rr[%],ari%);
|
||||||
|
rr += @; // ONCE
|
||||||
|
aa += @; // ONCE
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 4
|
||||||
|
// This block is automatically generated from the template above
|
||||||
|
// by the interleave.pl script. Please do not edit by hand
|
||||||
|
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||||
|
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||||
|
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||||
|
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||||
|
_mm256_storeu_pd(rr[0], ari0);
|
||||||
|
_mm256_storeu_pd(rr[1], ari1);
|
||||||
|
_mm256_storeu_pd(rr[2], ari2);
|
||||||
|
_mm256_storeu_pd(rr[3], ari3);
|
||||||
|
rr += 4; // ONCE
|
||||||
|
aa += 4; // ONCE
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (aa < aend);
|
||||||
|
}
|
||||||
85
spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c
Normal file
85
spqlios/lib/spqlios/cplx/cplx_fftvec_ref.c
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const CPLX* aa = (CPLX*)a;
|
||||||
|
const CPLX* bb = (CPLX*)b;
|
||||||
|
CPLX* rr = (CPLX*)r;
|
||||||
|
for (uint32_t i = 0; i < m; ++i) {
|
||||||
|
const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1];
|
||||||
|
const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0];
|
||||||
|
rr[i][0] += re;
|
||||||
|
rr[i][1] += im;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const CPLX* aa = (CPLX*)a;
|
||||||
|
const CPLX* bb = (CPLX*)b;
|
||||||
|
CPLX* rr = (CPLX*)r;
|
||||||
|
for (uint32_t i = 0; i < m; ++i) {
|
||||||
|
const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1];
|
||||||
|
const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0];
|
||||||
|
rr[i][0] = re;
|
||||||
|
rr[i][1] = im;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* init_cplx_fftvec_addmul_precomp(CPLX_FFTVEC_ADDMUL_PRECOMP* r, uint32_t m) {
|
||||||
|
if (m & (m - 1)) return spqlios_error("m must be a power of two");
|
||||||
|
r->m = m;
|
||||||
|
if (m <= 4) {
|
||||||
|
r->function = cplx_fftvec_addmul_ref;
|
||||||
|
} else if (CPU_SUPPORTS("fma")) {
|
||||||
|
r->function = cplx_fftvec_addmul_fma;
|
||||||
|
} else {
|
||||||
|
r->function = cplx_fftvec_addmul_ref;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* init_cplx_fftvec_mul_precomp(CPLX_FFTVEC_MUL_PRECOMP* r, uint32_t m) {
|
||||||
|
if (m & (m - 1)) return spqlios_error("m must be a power of two");
|
||||||
|
r->m = m;
|
||||||
|
if (m <= 4) {
|
||||||
|
r->function = cplx_fftvec_mul_ref;
|
||||||
|
} else if (CPU_SUPPORTS("fma")) {
|
||||||
|
r->function = cplx_fftvec_mul_fma;
|
||||||
|
} else {
|
||||||
|
r->function = cplx_fftvec_mul_ref;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) {
|
||||||
|
CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP));
|
||||||
|
return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m) {
|
||||||
|
CPLX_FFTVEC_MUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP));
|
||||||
|
return spqlios_keep_or_free(r, init_cplx_fftvec_mul_precomp(r, m));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
static CPLX_FFTVEC_MUL_PRECOMP p[31] = {0};
|
||||||
|
CPLX_FFTVEC_MUL_PRECOMP* f = p + log2m(m);
|
||||||
|
if (!f->function) {
|
||||||
|
if (!init_cplx_fftvec_mul_precomp(f, m)) abort();
|
||||||
|
}
|
||||||
|
f->function(f, r, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) {
|
||||||
|
static CPLX_FFTVEC_ADDMUL_PRECOMP p[31] = {0};
|
||||||
|
CPLX_FFTVEC_ADDMUL_PRECOMP* f = p + log2m(m);
|
||||||
|
if (!f->function) {
|
||||||
|
if (!init_cplx_fftvec_addmul_precomp(f, m)) abort();
|
||||||
|
}
|
||||||
|
f->function(f, r, a, b);
|
||||||
|
}
|
||||||
157
spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s
Normal file
157
spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma.s
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
# shifted FFT over X^16-i
|
||||||
|
# 1st argument (rdi) contains 16 complexes
|
||||||
|
# 2nd argument (rsi) contains: 8 complexes
|
||||||
|
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||||
|
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
# j = sqrt(i), k=sqrt(j)
|
||||||
|
.globl cplx_ifft16_avx_fma
|
||||||
|
cplx_ifft16_avx_fma:
|
||||||
|
vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15
|
||||||
|
vmovupd 0x20(%rdi),%ymm9
|
||||||
|
vmovupd 0x40(%rdi),%ymm10
|
||||||
|
vmovupd 0x60(%rdi),%ymm11
|
||||||
|
vmovupd 0x80(%rdi),%ymm12
|
||||||
|
vmovupd 0xa0(%rdi),%ymm13
|
||||||
|
vmovupd 0xc0(%rdi),%ymm14
|
||||||
|
vmovupd 0xe0(%rdi),%ymm15
|
||||||
|
|
||||||
|
.fourth_pass:
|
||||||
|
vmovupd 0(%rsi),%ymm0 /* gamma */
|
||||||
|
vmovupd 32(%rsi),%ymm2 /* delta */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5
|
||||||
|
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7
|
||||||
|
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13
|
||||||
|
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15
|
||||||
|
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma
|
||||||
|
vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta
|
||||||
|
vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm6,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm7,%ymm11,%ymm11
|
||||||
|
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm0,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm3,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm2,%ymm7
|
||||||
|
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||||
|
vfmsubadd231pd %ymm13, %ymm1, %ymm5
|
||||||
|
vfmaddsub231pd %ymm14, %ymm2, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||||
|
|
||||||
|
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
|
||||||
|
|
||||||
|
.third_pass:
|
||||||
|
vmovupd 64(%rsi),%xmm0 /* gamma */
|
||||||
|
vmovupd 80(%rsi),%xmm2 /* delta */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||||
|
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vsubpd %ymm9,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm11,%ymm10,%ymm5
|
||||||
|
vsubpd %ymm13,%ymm12,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm14,%ymm7
|
||||||
|
vaddpd %ymm9,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm11,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm13,%ymm12,%ymm12
|
||||||
|
vaddpd %ymm15,%ymm14,%ymm14
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm9
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm13
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm9,%ymm1,%ymm9
|
||||||
|
vmulpd %ymm11,%ymm0,%ymm11
|
||||||
|
vmulpd %ymm13,%ymm3,%ymm13
|
||||||
|
vmulpd %ymm15,%ymm2,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9
|
||||||
|
vfmsubadd231pd %ymm5, %ymm1, %ymm11
|
||||||
|
vfmaddsub231pd %ymm6, %ymm2, %ymm13
|
||||||
|
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||||
|
|
||||||
|
.second_pass:
|
||||||
|
vmovupd 96(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vsubpd %ymm10,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm11,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm12,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm13,%ymm7
|
||||||
|
vaddpd %ymm10,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm11,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm12,%ymm12
|
||||||
|
vaddpd %ymm15,%ymm13,%ymm13
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm10
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm10,%ymm1,%ymm10
|
||||||
|
vmulpd %ymm11,%ymm1,%ymm11
|
||||||
|
vmulpd %ymm14,%ymm0,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm0,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10
|
||||||
|
vfmaddsub231pd %ymm5, %ymm0, %ymm11
|
||||||
|
vfmsubadd231pd %ymm6, %ymm1, %ymm14
|
||||||
|
vfmsubadd231pd %ymm7, %ymm1, %ymm15
|
||||||
|
|
||||||
|
.first_pass:
|
||||||
|
vmovupd 112(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vsubpd %ymm12,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm13,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm10,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm11,%ymm7
|
||||||
|
vaddpd %ymm12,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm13,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm15,%ymm11,%ymm11
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm12,%ymm1,%ymm12
|
||||||
|
vmulpd %ymm13,%ymm1,%ymm13
|
||||||
|
vmulpd %ymm14,%ymm1,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm1,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||||
|
vfmaddsub231pd %ymm5, %ymm0, %ymm13
|
||||||
|
vfmaddsub231pd %ymm6, %ymm0, %ymm14
|
||||||
|
vfmaddsub231pd %ymm7, %ymm0, %ymm15
|
||||||
|
|
||||||
|
.save_and_return:
|
||||||
|
vmovupd %ymm8,(%rdi)
|
||||||
|
vmovupd %ymm9,0x20(%rdi)
|
||||||
|
vmovupd %ymm10,0x40(%rdi)
|
||||||
|
vmovupd %ymm11,0x60(%rdi)
|
||||||
|
vmovupd %ymm12,0x80(%rdi)
|
||||||
|
vmovupd %ymm13,0xa0(%rdi)
|
||||||
|
vmovupd %ymm14,0xc0(%rdi)
|
||||||
|
vmovupd %ymm15,0xe0(%rdi)
|
||||||
|
ret
|
||||||
|
.size cplx_ifft16_avx_fma, .-cplx_ifft16_avx_fma
|
||||||
|
.section .note.GNU-stack,"",@progbits
|
||||||
192
spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s
Normal file
192
spqlios/lib/spqlios/cplx/cplx_ifft16_avx_fma_win32.s
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
.text
|
||||||
|
.p2align 4
|
||||||
|
.globl cplx_ifft16_avx_fma
|
||||||
|
.def cplx_ifft16_avx_fma; .scl 2; .type 32; .endef
|
||||||
|
cplx_ifft16_avx_fma:
|
||||||
|
|
||||||
|
pushq %rdi
|
||||||
|
pushq %rsi
|
||||||
|
movq %rcx,%rdi
|
||||||
|
movq %rdx,%rsi
|
||||||
|
subq $0x100,%rsp
|
||||||
|
movdqu %xmm6,(%rsp)
|
||||||
|
movdqu %xmm7,0x10(%rsp)
|
||||||
|
movdqu %xmm8,0x20(%rsp)
|
||||||
|
movdqu %xmm9,0x30(%rsp)
|
||||||
|
movdqu %xmm10,0x40(%rsp)
|
||||||
|
movdqu %xmm11,0x50(%rsp)
|
||||||
|
movdqu %xmm12,0x60(%rsp)
|
||||||
|
movdqu %xmm13,0x70(%rsp)
|
||||||
|
movdqu %xmm14,0x80(%rsp)
|
||||||
|
movdqu %xmm15,0x90(%rsp)
|
||||||
|
callq cplx_ifft16_avx_fma_amd64
|
||||||
|
movdqu (%rsp),%xmm6
|
||||||
|
movdqu 0x10(%rsp),%xmm7
|
||||||
|
movdqu 0x20(%rsp),%xmm8
|
||||||
|
movdqu 0x30(%rsp),%xmm9
|
||||||
|
movdqu 0x40(%rsp),%xmm10
|
||||||
|
movdqu 0x50(%rsp),%xmm11
|
||||||
|
movdqu 0x60(%rsp),%xmm12
|
||||||
|
movdqu 0x70(%rsp),%xmm13
|
||||||
|
movdqu 0x80(%rsp),%xmm14
|
||||||
|
movdqu 0x90(%rsp),%xmm15
|
||||||
|
addq $0x100,%rsp
|
||||||
|
popq %rsi
|
||||||
|
popq %rdi
|
||||||
|
retq
|
||||||
|
|
||||||
|
# shifted FFT over X^16-i
|
||||||
|
# 1st argument (rdi) contains 16 complexes
|
||||||
|
# 2nd argument (rsi) contains: 8 complexes
|
||||||
|
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||||
|
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
# j = sqrt(i), k=sqrt(j)
|
||||||
|
|
||||||
|
cplx_ifft16_avx_fma_amd64:
|
||||||
|
vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15
|
||||||
|
vmovupd 0x20(%rdi),%ymm9
|
||||||
|
vmovupd 0x40(%rdi),%ymm10
|
||||||
|
vmovupd 0x60(%rdi),%ymm11
|
||||||
|
vmovupd 0x80(%rdi),%ymm12
|
||||||
|
vmovupd 0xa0(%rdi),%ymm13
|
||||||
|
vmovupd 0xc0(%rdi),%ymm14
|
||||||
|
vmovupd 0xe0(%rdi),%ymm15
|
||||||
|
|
||||||
|
.fourth_pass:
|
||||||
|
vmovupd 0(%rsi),%ymm0 /* gamma */
|
||||||
|
vmovupd 32(%rsi),%ymm2 /* delta */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5
|
||||||
|
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7
|
||||||
|
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13
|
||||||
|
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15
|
||||||
|
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma
|
||||||
|
vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma
|
||||||
|
vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta
|
||||||
|
vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta
|
||||||
|
vaddpd %ymm4,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm5,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm6,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm7,%ymm11,%ymm11
|
||||||
|
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||||
|
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||||
|
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||||
|
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||||
|
vmulpd %ymm4,%ymm1,%ymm4
|
||||||
|
vmulpd %ymm5,%ymm0,%ymm5
|
||||||
|
vmulpd %ymm6,%ymm3,%ymm6
|
||||||
|
vmulpd %ymm7,%ymm2,%ymm7
|
||||||
|
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||||
|
vfmsubadd231pd %ymm13, %ymm1, %ymm5
|
||||||
|
vfmaddsub231pd %ymm14, %ymm2, %ymm6
|
||||||
|
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||||
|
|
||||||
|
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||||
|
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||||
|
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||||
|
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||||
|
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||||
|
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||||
|
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||||
|
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||||
|
|
||||||
|
|
||||||
|
.third_pass:
|
||||||
|
vmovupd 64(%rsi),%xmm0 /* gamma */
|
||||||
|
vmovupd 80(%rsi),%xmm2 /* delta */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||||
|
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||||
|
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||||
|
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||||
|
vsubpd %ymm9,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm11,%ymm10,%ymm5
|
||||||
|
vsubpd %ymm13,%ymm12,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm14,%ymm7
|
||||||
|
vaddpd %ymm9,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm11,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm13,%ymm12,%ymm12
|
||||||
|
vaddpd %ymm15,%ymm14,%ymm14
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm9
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm13
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm9,%ymm1,%ymm9
|
||||||
|
vmulpd %ymm11,%ymm0,%ymm11
|
||||||
|
vmulpd %ymm13,%ymm3,%ymm13
|
||||||
|
vmulpd %ymm15,%ymm2,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9
|
||||||
|
vfmsubadd231pd %ymm5, %ymm1, %ymm11
|
||||||
|
vfmaddsub231pd %ymm6, %ymm2, %ymm13
|
||||||
|
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||||
|
|
||||||
|
.second_pass:
|
||||||
|
vmovupd 96(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vsubpd %ymm10,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm11,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm12,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm13,%ymm7
|
||||||
|
vaddpd %ymm10,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm11,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm12,%ymm12
|
||||||
|
vaddpd %ymm15,%ymm13,%ymm13
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm10
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm10,%ymm1,%ymm10
|
||||||
|
vmulpd %ymm11,%ymm1,%ymm11
|
||||||
|
vmulpd %ymm14,%ymm0,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm0,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10
|
||||||
|
vfmaddsub231pd %ymm5, %ymm0, %ymm11
|
||||||
|
vfmsubadd231pd %ymm6, %ymm1, %ymm14
|
||||||
|
vfmsubadd231pd %ymm7, %ymm1, %ymm15
|
||||||
|
|
||||||
|
.first_pass:
|
||||||
|
vmovupd 112(%rsi),%xmm0 /* omri */
|
||||||
|
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||||
|
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||||
|
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||||
|
vsubpd %ymm12,%ymm8,%ymm4
|
||||||
|
vsubpd %ymm13,%ymm9,%ymm5
|
||||||
|
vsubpd %ymm14,%ymm10,%ymm6
|
||||||
|
vsubpd %ymm15,%ymm11,%ymm7
|
||||||
|
vaddpd %ymm12,%ymm8,%ymm8
|
||||||
|
vaddpd %ymm13,%ymm9,%ymm9
|
||||||
|
vaddpd %ymm14,%ymm10,%ymm10
|
||||||
|
vaddpd %ymm15,%ymm11,%ymm11
|
||||||
|
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||||
|
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||||
|
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||||
|
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||||
|
vmulpd %ymm12,%ymm1,%ymm12
|
||||||
|
vmulpd %ymm13,%ymm1,%ymm13
|
||||||
|
vmulpd %ymm14,%ymm1,%ymm14
|
||||||
|
vmulpd %ymm15,%ymm1,%ymm15
|
||||||
|
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||||
|
vfmaddsub231pd %ymm5, %ymm0, %ymm13
|
||||||
|
vfmaddsub231pd %ymm6, %ymm0, %ymm14
|
||||||
|
vfmaddsub231pd %ymm7, %ymm0, %ymm15
|
||||||
|
|
||||||
|
.save_and_return:
|
||||||
|
vmovupd %ymm8,(%rdi)
|
||||||
|
vmovupd %ymm9,0x20(%rdi)
|
||||||
|
vmovupd %ymm10,0x40(%rdi)
|
||||||
|
vmovupd %ymm11,0x60(%rdi)
|
||||||
|
vmovupd %ymm12,0x80(%rdi)
|
||||||
|
vmovupd %ymm13,0xa0(%rdi)
|
||||||
|
vmovupd %ymm14,0xc0(%rdi)
|
||||||
|
vmovupd %ymm15,0xe0(%rdi)
|
||||||
|
ret
|
||||||
267
spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c
Normal file
267
spqlios/lib/spqlios/cplx/cplx_ifft_avx2_fma.c
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
typedef double D4MEM[4];
|
||||||
|
typedef double D2MEM[2];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex ifft via bfs strategy (for m between 2 and 8)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_ifft_avx2_fma_bfs_2(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||||
|
double* data = (double*)dat;
|
||||||
|
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||||
|
{
|
||||||
|
// loop with h = 1
|
||||||
|
// we do not do any particular optimization in this loop,
|
||||||
|
// since this function is only called for small dimensions
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
/*
|
||||||
|
BEGIN_TEMPLATE
|
||||||
|
const __m256d ab% = _mm256_loadu_pd(dd[0+2*%]);
|
||||||
|
const __m256d cd% = _mm256_loadu_pd(dd[1+2*%]);
|
||||||
|
const __m256d ac% = _mm256_permute2f128_pd(ab%, cd%, 0b100000);
|
||||||
|
const __m256d bd% = _mm256_permute2f128_pd(ab%, cd%, 0b110001);
|
||||||
|
const __m256d sum% = _mm256_add_pd(ac%, bd%);
|
||||||
|
const __m256d diff% = _mm256_sub_pd(ac%, bd%);
|
||||||
|
const __m256d diffbar% = _mm256_shuffle_pd(diff%, diff%, 5);
|
||||||
|
const __m256d om% = _mm256_load_pd((*omg)[0+%]);
|
||||||
|
const __m256d omre% = _mm256_unpacklo_pd(om%, om%);
|
||||||
|
const __m256d omim% = _mm256_unpackhi_pd(om%, om%);
|
||||||
|
const __m256d t1% = _mm256_mul_pd(diffbar%, omim%);
|
||||||
|
const __m256d t2% = _mm256_fmaddsub_pd(diff%, omre%, t1%);
|
||||||
|
const __m256d newab% = _mm256_permute2f128_pd(sum%, t2%, 0b100000);
|
||||||
|
const __m256d newcd% = _mm256_permute2f128_pd(sum%, t2%, 0b110001);
|
||||||
|
_mm256_storeu_pd(dd[0+2*%], newab%);
|
||||||
|
_mm256_storeu_pd(dd[1+2*%], newcd%);
|
||||||
|
dd += 2*@;
|
||||||
|
*omg += 2*@;
|
||||||
|
END_TEMPLATE
|
||||||
|
*/
|
||||||
|
// BEGIN_INTERLEAVE 1
|
||||||
|
const __m256d ab0 = _mm256_loadu_pd(dd[0 + 2 * 0]);
|
||||||
|
const __m256d cd0 = _mm256_loadu_pd(dd[1 + 2 * 0]);
|
||||||
|
const __m256d ac0 = _mm256_permute2f128_pd(ab0, cd0, 0b100000);
|
||||||
|
const __m256d bd0 = _mm256_permute2f128_pd(ab0, cd0, 0b110001);
|
||||||
|
const __m256d sum0 = _mm256_add_pd(ac0, bd0);
|
||||||
|
const __m256d diff0 = _mm256_sub_pd(ac0, bd0);
|
||||||
|
const __m256d diffbar0 = _mm256_shuffle_pd(diff0, diff0, 5);
|
||||||
|
const __m256d om0 = _mm256_load_pd((*omga)[0 + 0]);
|
||||||
|
const __m256d omre0 = _mm256_unpacklo_pd(om0, om0);
|
||||||
|
const __m256d omim0 = _mm256_unpackhi_pd(om0, om0);
|
||||||
|
const __m256d t10 = _mm256_mul_pd(diffbar0, omim0);
|
||||||
|
const __m256d t20 = _mm256_fmaddsub_pd(diff0, omre0, t10);
|
||||||
|
const __m256d newab0 = _mm256_permute2f128_pd(sum0, t20, 0b100000);
|
||||||
|
const __m256d newcd0 = _mm256_permute2f128_pd(sum0, t20, 0b110001);
|
||||||
|
_mm256_storeu_pd(dd[0 + 2 * 0], newab0);
|
||||||
|
_mm256_storeu_pd(dd[1 + 2 * 0], newcd0);
|
||||||
|
dd += 2 * 1;
|
||||||
|
*omga += 2 * 1;
|
||||||
|
// END_INTERLEAVE
|
||||||
|
} while (dd < finaldd);
|
||||||
|
#if 0
|
||||||
|
printf("c after first: ");
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",ddata[ii][0],ddata[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
// general case
|
||||||
|
const uint32_t ms2 = m >> 1;
|
||||||
|
for (uint32_t _2nblock = 2; _2nblock <= ms2; _2nblock <<= 1) {
|
||||||
|
// _2nblock = h in ref code
|
||||||
|
uint32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m256d om = _mm256_load_pd((*omga)[0]);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, b);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
const __m256d diff = _mm256_sub_pd(a, b);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||||
|
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||||
|
_mm256_storeu_pd(ddmid[0], t2);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
dd += nblock;
|
||||||
|
*omga += 2;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex fft via bfs strategy (for m >= 16)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_ifft_avx2_fma_bfs_16(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||||
|
double* data = (double*)dat;
|
||||||
|
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||||
|
// base iteration when h = _2nblock == 8
|
||||||
|
{
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
cplx_ifft16_avx_fma(dd, *omga);
|
||||||
|
dd += 8;
|
||||||
|
*omga += 8;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
}
|
||||||
|
// general case
|
||||||
|
const uint32_t log2m = _mm_popcnt_u32(m-1); //_popcnt32(m-1); //log2(m);
|
||||||
|
uint32_t h=16;
|
||||||
|
if (log2m % 2 == 1) {
|
||||||
|
uint32_t nblock = h >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m128d om1 = _mm_loadu_pd((*omga)[0]);
|
||||||
|
const __m256d om = _mm256_set_m128d(om1,om1);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, b);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
const __m256d diff = _mm256_sub_pd(a, b);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||||
|
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||||
|
_mm256_storeu_pd(ddmid[0], t2);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
dd += nblock;
|
||||||
|
*omga += 1;
|
||||||
|
} while (dd < finaldd);
|
||||||
|
h = 32;
|
||||||
|
}
|
||||||
|
for (; h < m; h <<= 2) {
|
||||||
|
// _2nblock = h in ref code
|
||||||
|
uint32_t nblock = h >> 1; // =h/2 in ref code
|
||||||
|
D4MEM* dd0 = (D4MEM*)data;
|
||||||
|
do {
|
||||||
|
const __m128d om1 = _mm_loadu_pd((*omga)[0]);
|
||||||
|
const __m128d al1 = _mm_loadu_pd((*omga)[1]);
|
||||||
|
const __m256d om = _mm256_set_m128d(om1,om1);
|
||||||
|
const __m256d al = _mm256_set_m128d(al1,al1);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||||
|
const __m256d alre = _mm256_unpacklo_pd(al, al);
|
||||||
|
const __m256d alim = _mm256_unpackhi_pd(al, al);
|
||||||
|
D4MEM* const ddend = (dd0 + nblock);
|
||||||
|
D4MEM* dd1 = ddend;
|
||||||
|
D4MEM* dd2 = dd1 + nblock;
|
||||||
|
D4MEM* dd3 = dd2 + nblock;
|
||||||
|
do {
|
||||||
|
__m256d u0 = _mm256_loadu_pd(dd0[0]);
|
||||||
|
__m256d u1 = _mm256_loadu_pd(dd1[0]);
|
||||||
|
__m256d u2 = _mm256_loadu_pd(dd2[0]);
|
||||||
|
__m256d u3 = _mm256_loadu_pd(dd3[0]);
|
||||||
|
__m256d u4 = _mm256_add_pd(u0, u1);
|
||||||
|
__m256d u5 = _mm256_sub_pd(u0, u1);
|
||||||
|
__m256d u6 = _mm256_add_pd(u2, u3);
|
||||||
|
__m256d u7 = _mm256_sub_pd(u2, u3);
|
||||||
|
u0 = _mm256_shuffle_pd(u5, u5, 5);
|
||||||
|
u2 = _mm256_shuffle_pd(u7, u7, 5);
|
||||||
|
u1 = _mm256_mul_pd(u0, omim);
|
||||||
|
u3 = _mm256_mul_pd(u2, omre);
|
||||||
|
u5 = _mm256_fmaddsub_pd(u5,omre, u1);
|
||||||
|
u7 = _mm256_fmsubadd_pd(u7,omim, u3);
|
||||||
|
//////
|
||||||
|
u0 = _mm256_add_pd(u4,u6);
|
||||||
|
u1 = _mm256_add_pd(u5,u7);
|
||||||
|
u2 = _mm256_sub_pd(u4,u6);
|
||||||
|
u3 = _mm256_sub_pd(u5,u7);
|
||||||
|
u4 = _mm256_shuffle_pd(u2, u2, 5);
|
||||||
|
u5 = _mm256_shuffle_pd(u3, u3, 5);
|
||||||
|
u6 = _mm256_mul_pd(u4, alim);
|
||||||
|
u7 = _mm256_mul_pd(u5, alim);
|
||||||
|
u2 = _mm256_fmaddsub_pd(u2,alre, u6);
|
||||||
|
u3 = _mm256_fmaddsub_pd(u3,alre, u7);
|
||||||
|
///////
|
||||||
|
_mm256_storeu_pd(dd0[0], u0);
|
||||||
|
_mm256_storeu_pd(dd1[0], u1);
|
||||||
|
_mm256_storeu_pd(dd2[0], u2);
|
||||||
|
_mm256_storeu_pd(dd3[0], u3);
|
||||||
|
dd0 += 1;
|
||||||
|
dd1 += 1;
|
||||||
|
dd2 += 1;
|
||||||
|
dd3 += 1;
|
||||||
|
} while (dd0 < ddend);
|
||||||
|
dd0 += 3*nblock;
|
||||||
|
*omga += 2;
|
||||||
|
} while (dd0 < finaldd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex ifft via dfs recursion (for m >= 16)
|
||||||
|
* @param dat the data to run the algorithm on
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
void cplx_ifft_avx2_fma_rec_16(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||||
|
if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(dat, omga, m);
|
||||||
|
if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(dat, omga, m);
|
||||||
|
const uint32_t _2nblock = m >> 1; // = h in ref code
|
||||||
|
const uint32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||||
|
cplx_ifft_avx2_fma_rec_16(dat, omga, _2nblock);
|
||||||
|
cplx_ifft_avx2_fma_rec_16(dat + nblock, omga, _2nblock);
|
||||||
|
{
|
||||||
|
// final iteration
|
||||||
|
D4MEM* dd = dat;
|
||||||
|
const __m256d om = _mm256_load_pd((*omga)[0]);
|
||||||
|
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||||
|
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||||
|
D4MEM* const ddend = (dd + nblock);
|
||||||
|
D4MEM* ddmid = ddend;
|
||||||
|
do {
|
||||||
|
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||||
|
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||||
|
const __m256d newa = _mm256_add_pd(a, b);
|
||||||
|
_mm256_storeu_pd(dd[0], newa);
|
||||||
|
const __m256d diff = _mm256_sub_pd(a, b);
|
||||||
|
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||||
|
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||||
|
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||||
|
_mm256_storeu_pd(ddmid[0], t2);
|
||||||
|
dd += 1;
|
||||||
|
ddmid += 1;
|
||||||
|
} while (dd < ddend);
|
||||||
|
*omga += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief complex ifft via best strategy (for m>=1)
|
||||||
|
* @param dat the data to run the algorithm on: m complex numbers
|
||||||
|
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||||
|
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||||
|
*/
|
||||||
|
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* precomp, void* d) {
|
||||||
|
const uint32_t m = precomp->m;
|
||||||
|
const D2MEM* omg = (D2MEM*)precomp->powomegas;
|
||||||
|
if (m <= 1) return;
|
||||||
|
if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(d, &omg, m);
|
||||||
|
if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(d, &omg, m);
|
||||||
|
cplx_ifft_avx2_fma_rec_16(d, &omg, m);
|
||||||
|
}
|
||||||
315
spqlios/lib/spqlios/cplx/cplx_ifft_ref.c
Normal file
315
spqlios/lib/spqlios/cplx/cplx_ifft_ref.c
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "cplx_fft.h"
|
||||||
|
#include "cplx_fft_internal.h"
|
||||||
|
#include "cplx_fft_private.h"
|
||||||
|
|
||||||
|
/** @brief (a,b) <- (a+b,omegabar.(a-b)) */
|
||||||
|
void invctwiddle(CPLX a, CPLX b, const CPLX ombar) {
|
||||||
|
double diffre = a[0] - b[0];
|
||||||
|
double diffim = a[1] - b[1];
|
||||||
|
a[0] = a[0] + b[0];
|
||||||
|
a[1] = a[1] + b[1];
|
||||||
|
b[0] = diffre * ombar[0] - diffim * ombar[1];
|
||||||
|
b[1] = diffre * ombar[1] + diffim * ombar[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief (a,b) <- (a+b,-i.omegabar(a-b)) */
|
||||||
|
void invcitwiddle(CPLX a, CPLX b, const CPLX ombar) {
|
||||||
|
double diffre = a[0] - b[0];
|
||||||
|
double diffim = a[1] - b[1];
|
||||||
|
a[0] = a[0] + b[0];
|
||||||
|
a[1] = a[1] + b[1];
|
||||||
|
//-i(x+iy)=-ix+y
|
||||||
|
b[0] = diffre * ombar[1] + diffim * ombar[0];
|
||||||
|
b[1] = -diffre * ombar[0] + diffim * ombar[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief exp(-i.2pi.x) */
|
||||||
|
void cplx_set_e2pimx(CPLX res, double x) {
|
||||||
|
res[0] = m_accurate_cos(2 * M_PI * x);
|
||||||
|
res[1] = -m_accurate_sin(2 * M_PI * x);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||||
|
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||||
|
double fracrevbits(uint32_t i);
|
||||||
|
/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */
|
||||||
|
void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data) {
|
||||||
|
if (m == 1) return;
|
||||||
|
const double pom = entry_pwr / 2.;
|
||||||
|
const uint32_t h = m / 2;
|
||||||
|
CPLX cpom;
|
||||||
|
cplx_set_e2pimx(cpom, pom);
|
||||||
|
// do the recursive calls
|
||||||
|
cplx_ifft_naive(h, pom, data);
|
||||||
|
cplx_ifft_naive(h, pom + 0.5, data + h);
|
||||||
|
// apply the inverse twiddle factors
|
||||||
|
for (uint64_t i = 0; i < h; ++i) {
|
||||||
|
invctwiddle(data[i], data[i + h], cpom);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_ifft16_precomp(const double entry_pwr, CPLX** omg) {
|
||||||
|
static const double j_pow = 1. / 8.;
|
||||||
|
static const double k_pow = 1. / 16.;
|
||||||
|
const double pom = entry_pwr / 2.;
|
||||||
|
const double pom_2 = entry_pwr / 4.;
|
||||||
|
const double pom_4 = entry_pwr / 8.;
|
||||||
|
const double pom_8 = entry_pwr / 16.;
|
||||||
|
cplx_set_e2pimx((*omg)[0], pom_8);
|
||||||
|
cplx_set_e2pimx((*omg)[1], pom_8 + j_pow);
|
||||||
|
cplx_set_e2pimx((*omg)[2], pom_8 + k_pow);
|
||||||
|
cplx_set_e2pimx((*omg)[3], pom_8 + j_pow + k_pow);
|
||||||
|
cplx_set_e2pimx((*omg)[4], pom_4);
|
||||||
|
cplx_set_e2pimx((*omg)[5], pom_4 + j_pow);
|
||||||
|
cplx_set_e2pimx((*omg)[6], pom_2);
|
||||||
|
cplx_set_e2pimx((*omg)[7], pom);
|
||||||
|
*omg += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief iFFT modulo X^16-omega^2 (in registers)
|
||||||
|
* @param data contains 16 complexes
|
||||||
|
* @param omegabar 8 complexes in this order:
|
||||||
|
* gammabar,jb.gammabar,kb.gammabar,kbjb.gammabar,
|
||||||
|
* betabar,jb.betabar,alphabar,omegabar
|
||||||
|
* alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||||
|
* jb = sqrt(ib), kb=sqrt(jb)
|
||||||
|
*/
|
||||||
|
void cplx_ifft16_ref(void* data, const void* omegabar) {
|
||||||
|
CPLX* d = data;
|
||||||
|
const CPLX* om = omegabar;
|
||||||
|
// fourth pass inverse
|
||||||
|
invctwiddle(d[0], d[1], om[0]);
|
||||||
|
invcitwiddle(d[2], d[3], om[0]);
|
||||||
|
invctwiddle(d[4], d[5], om[1]);
|
||||||
|
invcitwiddle(d[6], d[7], om[1]);
|
||||||
|
invctwiddle(d[8], d[9], om[2]);
|
||||||
|
invcitwiddle(d[10], d[11], om[2]);
|
||||||
|
invctwiddle(d[12], d[13], om[3]);
|
||||||
|
invcitwiddle(d[14], d[15], om[3]);
|
||||||
|
// third pass inverse
|
||||||
|
invctwiddle(d[0], d[2], om[4]);
|
||||||
|
invctwiddle(d[1], d[3], om[4]);
|
||||||
|
invcitwiddle(d[4], d[6], om[4]);
|
||||||
|
invcitwiddle(d[5], d[7], om[4]);
|
||||||
|
invctwiddle(d[8], d[10], om[5]);
|
||||||
|
invctwiddle(d[9], d[11], om[5]);
|
||||||
|
invcitwiddle(d[12], d[14], om[5]);
|
||||||
|
invcitwiddle(d[13], d[15], om[5]);
|
||||||
|
// second pass inverse
|
||||||
|
invctwiddle(d[0], d[4], om[6]);
|
||||||
|
invctwiddle(d[1], d[5], om[6]);
|
||||||
|
invctwiddle(d[2], d[6], om[6]);
|
||||||
|
invctwiddle(d[3], d[7], om[6]);
|
||||||
|
invcitwiddle(d[8], d[12], om[6]);
|
||||||
|
invcitwiddle(d[9], d[13], om[6]);
|
||||||
|
invcitwiddle(d[10], d[14], om[6]);
|
||||||
|
invcitwiddle(d[11], d[15], om[6]);
|
||||||
|
// first pass
|
||||||
|
for (uint64_t i = 0; i < 8; ++i) {
|
||||||
|
invctwiddle(d[0 + i], d[8 + i], om[7]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_ifft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
CPLX* const dend = dat + m;
|
||||||
|
for (CPLX* d = dat; d < dend; d += 2) {
|
||||||
|
split_fft_last_ref(d, (*omg)[0]);
|
||||||
|
*omg += 1;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after first: ");
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
int32_t Ms2 = m / 2;
|
||||||
|
for (int32_t h = 2; h <= Ms2; h <<= 1) {
|
||||||
|
for (CPLX* d = dat; d < dend; d += 2 * h) {
|
||||||
|
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||||
|
cplx_split_fft_ref(h, d, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after split %d: ", h);
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_ifft_ref_bfs_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
const uint64_t log2m = log2(m);
|
||||||
|
CPLX* const dend = dat + m;
|
||||||
|
// h=1,2,4,8 use the 16-dim macroblock
|
||||||
|
for (CPLX* d = dat; d < dend; d += 16) {
|
||||||
|
cplx_ifft16_ref(d, *omg);
|
||||||
|
*omg += 8;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after first: ");
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
int32_t h = 16;
|
||||||
|
if (log2m % 2 != 0) {
|
||||||
|
// if parity needs it, uses one regular twiddle
|
||||||
|
for (CPLX* d = dat; d < dend; d += 2 * h) {
|
||||||
|
cplx_split_fft_ref(h, d, **omg);
|
||||||
|
*omg += 1;
|
||||||
|
}
|
||||||
|
h = 32;
|
||||||
|
}
|
||||||
|
// h=16,...,2*floor(Ms2/2) use the bitwiddle
|
||||||
|
for (; h < m; h <<= 2) {
|
||||||
|
for (CPLX* d = dat; d < dend; d += 4 * h) {
|
||||||
|
cplx_bisplit_fft_ref(h, d, *omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
#if 0
|
||||||
|
printf("after split %d: ", h);
|
||||||
|
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||||
|
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_cplx_ifft_omegas_bfs_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
const uint64_t log2m = log2(m);
|
||||||
|
double pwr = entry_pwr * 16. / m;
|
||||||
|
{
|
||||||
|
// h=8
|
||||||
|
for (uint32_t i = 0; i < m / 16; i++) {
|
||||||
|
cplx_ifft16_precomp(pwr + fracrevbits(i), omg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int32_t h = 16;
|
||||||
|
if (log2m % 2 != 0) {
|
||||||
|
// if parity needs it, uses one regular twiddle
|
||||||
|
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||||
|
cplx_set_e2pimx(omg[0][0], pwr + fracrevbits(i) / 2.);
|
||||||
|
*omg += 1;
|
||||||
|
}
|
||||||
|
pwr *= 2.;
|
||||||
|
h = 32;
|
||||||
|
}
|
||||||
|
for (; h < m; h <<= 2) {
|
||||||
|
for (uint32_t i = 0; i < m / (2 * h); i+=2) {
|
||||||
|
cplx_set_e2pimx(omg[0][0], pwr + fracrevbits(i) / 2.);
|
||||||
|
cplx_set_e2pimx(omg[0][1], 2.*pwr + fracrevbits(i));
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
pwr *= 4.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_cplx_ifft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
double pom = entry_pwr / m;
|
||||||
|
{
|
||||||
|
// h=1
|
||||||
|
for (uint32_t i = 0; i < m / 2; i++) {
|
||||||
|
cplx_set_e2pimx((*omg)[0], pom + fracrevbits(i) / 2.);
|
||||||
|
*omg += 1; // optim function reads by 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int32_t h = 2; h <= m / 2; h <<= 1) {
|
||||||
|
pom *= 2;
|
||||||
|
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||||
|
cplx_set_e2pimx(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||||
|
cplx_set(omg[0][1], omg[0][0]);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_cplx_ifft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return fill_cplx_ifft_omegas_bfs_2(entry_pwr, omg, m);
|
||||||
|
if (m <= 2048) return fill_cplx_ifft_omegas_bfs_16(entry_pwr, omg, m);
|
||||||
|
double pom = entry_pwr / 2.;
|
||||||
|
uint32_t h = m / 2;
|
||||||
|
fill_cplx_ifft_omegas_rec_16(pom, omg, h);
|
||||||
|
fill_cplx_ifft_omegas_rec_16(pom + 0.5, omg, h);
|
||||||
|
cplx_set_e2pimx((*omg)[0], pom);
|
||||||
|
cplx_set((*omg)[1], (*omg)[0]);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cplx_ifft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return cplx_ifft_ref_bfs_2(dat, omg, m);
|
||||||
|
if (m <= 2048) return cplx_ifft_ref_bfs_16(dat, omg, m);
|
||||||
|
const uint32_t h = m / 2;
|
||||||
|
cplx_ifft_ref_rec_16(dat, omg, h);
|
||||||
|
cplx_ifft_ref_rec_16(dat + h, omg, h);
|
||||||
|
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||||
|
cplx_split_fft_ref(h, dat, **omg);
|
||||||
|
*omg += 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* precomp, void* d) {
|
||||||
|
CPLX* data = (CPLX*)d;
|
||||||
|
const int32_t m = precomp->m;
|
||||||
|
const CPLX* omg = (CPLX*)precomp->powomegas;
|
||||||
|
if (m == 1) return;
|
||||||
|
if (m <= 8) return cplx_ifft_ref_bfs_2(data, &omg, m);
|
||||||
|
if (m <= 2048) return cplx_ifft_ref_bfs_16(data, &omg, m);
|
||||||
|
cplx_ifft_ref_rec_16(data, &omg, m);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers) {
|
||||||
|
const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(CPLX));
|
||||||
|
const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX));
|
||||||
|
void* reps = malloc(sizeof(CPLX_IFFT_PRECOMP) + 63 // padding
|
||||||
|
+ OMG_SPACE // tables
|
||||||
|
+ num_buffers * BUF_SIZE // buffers
|
||||||
|
);
|
||||||
|
uint64_t aligned_addr = ceilto64b((uint64_t) reps + sizeof(CPLX_IFFT_PRECOMP));
|
||||||
|
CPLX_IFFT_PRECOMP* r = (CPLX_IFFT_PRECOMP*)reps;
|
||||||
|
r->m = m;
|
||||||
|
r->buf_size = BUF_SIZE;
|
||||||
|
r->powomegas = (double*)aligned_addr;
|
||||||
|
r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE);
|
||||||
|
// fill in powomegas
|
||||||
|
CPLX* omg = (CPLX*)r->powomegas;
|
||||||
|
fill_cplx_ifft_omegas_rec_16(0.25, &omg, m);
|
||||||
|
if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort();
|
||||||
|
{
|
||||||
|
if (m <= 4) {
|
||||||
|
// currently, we do not have any acceletated
|
||||||
|
// implementation for m<=4
|
||||||
|
r->function = cplx_ifft_ref;
|
||||||
|
} else if (CPU_SUPPORTS("fma")) {
|
||||||
|
r->function = cplx_ifft_avx2_fma;
|
||||||
|
} else {
|
||||||
|
r->function = cplx_ifft_ref;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reps;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* itables, uint32_t buffer_index) {
|
||||||
|
return (uint8_t*) itables->aligned_buffers + buffer_index * itables->buf_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
|
||||||
|
itables->function(itables, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void cplx_ifft_simple(uint32_t m, void* data) {
|
||||||
|
static CPLX_IFFT_PRECOMP* p[31] = {0};
|
||||||
|
CPLX_IFFT_PRECOMP** f = p + log2m(m);
|
||||||
|
if (!*f) *f = new_cplx_ifft_precomp(m, 0);
|
||||||
|
(*f)->function(*f, data);
|
||||||
|
}
|
||||||
|
|
||||||
0
spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c
Normal file
0
spqlios/lib/spqlios/cplx/spqlios_cplx_fft.c
Normal file
138
spqlios/lib/spqlios/ext/neon_accel/macrof.h
Normal file
138
spqlios/lib/spqlios/ext/neon_accel/macrof.h
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
/*
|
||||||
|
* This file is extracted from the implementation of the FFT on Arm64/Neon
|
||||||
|
* available in https://github.com/cothan/Falcon-Arm (neon/macrof.h).
|
||||||
|
* =============================================================================
|
||||||
|
* Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG)
|
||||||
|
* ECE Department, George Mason University
|
||||||
|
* Fairfax, VA, U.S.A.
|
||||||
|
* @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* =============================================================================
|
||||||
|
*
|
||||||
|
* This 64-bit Floating point NEON macro x1 has not been modified and is provided as is.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MACROF_H
|
||||||
|
#define MACROF_H
|
||||||
|
|
||||||
|
#include <arm_neon.h>
|
||||||
|
|
||||||
|
// c <= addr x1
|
||||||
|
#define vload(c, addr) c = vld1q_f64(addr);
|
||||||
|
// c <= addr interleave 2
|
||||||
|
#define vload2(c, addr) c = vld2q_f64(addr);
|
||||||
|
// c <= addr interleave 4
|
||||||
|
#define vload4(c, addr) c = vld4q_f64(addr);
|
||||||
|
|
||||||
|
#define vstore(addr, c) vst1q_f64(addr, c);
|
||||||
|
// addr <= c
|
||||||
|
#define vstore2(addr, c) vst2q_f64(addr, c);
|
||||||
|
// addr <= c
|
||||||
|
#define vstore4(addr, c) vst4q_f64(addr, c);
|
||||||
|
|
||||||
|
// c <= addr x2
|
||||||
|
#define vloadx2(c, addr) c = vld1q_f64_x2(addr);
|
||||||
|
// c <= addr x3
|
||||||
|
#define vloadx3(c, addr) c = vld1q_f64_x3(addr);
|
||||||
|
|
||||||
|
// addr <= c
|
||||||
|
#define vstorex2(addr, c) vst1q_f64_x2(addr, c);
|
||||||
|
|
||||||
|
// c = a - b
|
||||||
|
#define vfsub(c, a, b) c = vsubq_f64(a, b);
|
||||||
|
|
||||||
|
// c = a + b
|
||||||
|
#define vfadd(c, a, b) c = vaddq_f64(a, b);
|
||||||
|
|
||||||
|
// c = a * b
|
||||||
|
#define vfmul(c, a, b) c = vmulq_f64(a, b);
|
||||||
|
|
||||||
|
// c = a * n (n is constant)
|
||||||
|
#define vfmuln(c, a, n) c = vmulq_n_f64(a, n);
|
||||||
|
|
||||||
|
// Swap from a|b to b|a
|
||||||
|
#define vswap(c, a) c = vextq_f64(a, a, 1);
|
||||||
|
|
||||||
|
// c = a * b[i]
|
||||||
|
#define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i);
|
||||||
|
|
||||||
|
// c = 1/a
|
||||||
|
#define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a);
|
||||||
|
|
||||||
|
// c = -a
|
||||||
|
#define vfneg(c, a) c = vnegq_f64(a);
|
||||||
|
|
||||||
|
#define transpose_f64(a, b, t, ia, ib, it) \
|
||||||
|
t.val[it] = a.val[ia]; \
|
||||||
|
a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \
|
||||||
|
b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* c = a + jb
|
||||||
|
* c[0] = a[0] - b[1]
|
||||||
|
* c[1] = a[1] + b[0]
|
||||||
|
*/
|
||||||
|
#define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* c = a - jb
|
||||||
|
* c[0] = a[0] + b[1]
|
||||||
|
* c[1] = a[1] - b[0]
|
||||||
|
*/
|
||||||
|
#define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b);
|
||||||
|
|
||||||
|
// c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0]
|
||||||
|
#define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b);
|
||||||
|
|
||||||
|
// c[0] = c[0] - b[1]*a[1], c[1] = c[1] + b[0]*a[1]
|
||||||
|
#define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b);
|
||||||
|
|
||||||
|
// c[0] = c[0] - b[0]*a[0], c[1] = c[1] - b[1]*a[0]
|
||||||
|
#define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b);
|
||||||
|
|
||||||
|
// c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1]
|
||||||
|
#define vfcmla_270(c, a, b) c = vcmlaq_rot270_f64(c, a, b);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Complex MUL: c = a*b
|
||||||
|
* c[0] = a[0]*b[0] - a[1]*b[1]
|
||||||
|
* c[1] = a[0]*b[1] + a[1]*b[0]
|
||||||
|
*/
|
||||||
|
#define FPC_CMUL(c, a, b) \
|
||||||
|
c = vmulq_laneq_f64(b, a, 0); \
|
||||||
|
c = vcmlaq_rot90_f64(c, a, b);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1])
|
||||||
|
* c[0] = b[0]*a[0] + b[1]*a[1]
|
||||||
|
* c[1] = + b[0]*a[1] - b[1]*a[0]
|
||||||
|
*/
|
||||||
|
#define FPC_CMUL_CONJ(c, a, b) \
|
||||||
|
c = vmulq_laneq_f64(a, b, 0); \
|
||||||
|
c = vcmlaq_rot270_f64(c, b, a);
|
||||||
|
|
||||||
|
#if FMA == 1
|
||||||
|
// d = c + a *b
|
||||||
|
#define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b);
|
||||||
|
// d = c - a * b
|
||||||
|
#define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b);
|
||||||
|
// d = c + a * b[i]
|
||||||
|
#define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i);
|
||||||
|
// d = c - a * b[i]
|
||||||
|
#define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i);
|
||||||
|
|
||||||
|
#else
|
||||||
|
// d = c + a *b
|
||||||
|
#define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b));
|
||||||
|
// d = c - a *b
|
||||||
|
#define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b));
|
||||||
|
// d = c + a * b[i]
|
||||||
|
#define vfmla_lane(d, c, a, b, i) \
|
||||||
|
d = vaddq_f64(c, vmulq_laneq_f64(a, b, i));
|
||||||
|
|
||||||
|
#define vfmls_lane(d, c, a, b, i) \
|
||||||
|
d = vsubq_f64(c, vmulq_laneq_f64(a, b, i));
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
428
spqlios/lib/spqlios/ext/neon_accel/macrofx4.h
Normal file
428
spqlios/lib/spqlios/ext/neon_accel/macrofx4.h
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
/*
|
||||||
|
* This file is extracted from the implementation of the FFT on Arm64/Neon
|
||||||
|
* available in https://github.com/cothan/Falcon-Arm (neon/macrof.h).
|
||||||
|
* =============================================================================
|
||||||
|
* Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG)
|
||||||
|
* ECE Department, George Mason University
|
||||||
|
* Fairfax, VA, U.S.A.
|
||||||
|
* @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* =============================================================================
|
||||||
|
*
|
||||||
|
* This 64-bit Floating point NEON macro x4 has not been modified and is provided as is.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MACROFX4_H
|
||||||
|
#define MACROFX4_H
|
||||||
|
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#include "macrof.h"
|
||||||
|
|
||||||
|
#define vloadx4(c, addr) c = vld1q_f64_x4(addr);
|
||||||
|
|
||||||
|
#define vstorex4(addr, c) vst1q_f64_x4(addr, c);
|
||||||
|
|
||||||
|
#define vfdupx4(c, constant) \
|
||||||
|
c.val[0] = vdupq_n_f64(constant); \
|
||||||
|
c.val[1] = vdupq_n_f64(constant); \
|
||||||
|
c.val[2] = vdupq_n_f64(constant); \
|
||||||
|
c.val[3] = vdupq_n_f64(constant);
|
||||||
|
|
||||||
|
#define vfnegx4(c, a) \
|
||||||
|
c.val[0] = vnegq_f64(a.val[0]); \
|
||||||
|
c.val[1] = vnegq_f64(a.val[1]); \
|
||||||
|
c.val[2] = vnegq_f64(a.val[2]); \
|
||||||
|
c.val[3] = vnegq_f64(a.val[3]);
|
||||||
|
|
||||||
|
#define vfmulnx4(c, a, n) \
|
||||||
|
c.val[0] = vmulq_n_f64(a.val[0], n); \
|
||||||
|
c.val[1] = vmulq_n_f64(a.val[1], n); \
|
||||||
|
c.val[2] = vmulq_n_f64(a.val[2], n); \
|
||||||
|
c.val[3] = vmulq_n_f64(a.val[3], n);
|
||||||
|
|
||||||
|
// c = a - b
|
||||||
|
#define vfsubx4(c, a, b) \
|
||||||
|
c.val[0] = vsubq_f64(a.val[0], b.val[0]); \
|
||||||
|
c.val[1] = vsubq_f64(a.val[1], b.val[1]); \
|
||||||
|
c.val[2] = vsubq_f64(a.val[2], b.val[2]); \
|
||||||
|
c.val[3] = vsubq_f64(a.val[3], b.val[3]);
|
||||||
|
|
||||||
|
// c = a + b
|
||||||
|
#define vfaddx4(c, a, b) \
|
||||||
|
c.val[0] = vaddq_f64(a.val[0], b.val[0]); \
|
||||||
|
c.val[1] = vaddq_f64(a.val[1], b.val[1]); \
|
||||||
|
c.val[2] = vaddq_f64(a.val[2], b.val[2]); \
|
||||||
|
c.val[3] = vaddq_f64(a.val[3], b.val[3]);
|
||||||
|
|
||||||
|
#define vfmulx4(c, a, b) \
|
||||||
|
c.val[0] = vmulq_f64(a.val[0], b.val[0]); \
|
||||||
|
c.val[1] = vmulq_f64(a.val[1], b.val[1]); \
|
||||||
|
c.val[2] = vmulq_f64(a.val[2], b.val[2]); \
|
||||||
|
c.val[3] = vmulq_f64(a.val[3], b.val[3]);
|
||||||
|
|
||||||
|
#define vfmulx4_i(c, a, b) \
|
||||||
|
c.val[0] = vmulq_f64(a.val[0], b); \
|
||||||
|
c.val[1] = vmulq_f64(a.val[1], b); \
|
||||||
|
c.val[2] = vmulq_f64(a.val[2], b); \
|
||||||
|
c.val[3] = vmulq_f64(a.val[3], b);
|
||||||
|
|
||||||
|
#define vfinvx4(c, a) \
|
||||||
|
c.val[0] = vdivq_f64(vdupq_n_f64(1.0), a.val[0]); \
|
||||||
|
c.val[1] = vdivq_f64(vdupq_n_f64(1.0), a.val[1]); \
|
||||||
|
c.val[2] = vdivq_f64(vdupq_n_f64(1.0), a.val[2]); \
|
||||||
|
c.val[3] = vdivq_f64(vdupq_n_f64(1.0), a.val[3]);
|
||||||
|
|
||||||
|
#define vfcvtx4(c, a) \
|
||||||
|
c.val[0] = vcvtq_f64_s64(a.val[0]); \
|
||||||
|
c.val[1] = vcvtq_f64_s64(a.val[1]); \
|
||||||
|
c.val[2] = vcvtq_f64_s64(a.val[2]); \
|
||||||
|
c.val[3] = vcvtq_f64_s64(a.val[3]);
|
||||||
|
|
||||||
|
#define vfmlax4(d, c, a, b) \
|
||||||
|
vfmla(d.val[0], c.val[0], a.val[0], b.val[0]); \
|
||||||
|
vfmla(d.val[1], c.val[1], a.val[1], b.val[1]); \
|
||||||
|
vfmla(d.val[2], c.val[2], a.val[2], b.val[2]); \
|
||||||
|
vfmla(d.val[3], c.val[3], a.val[3], b.val[3]);
|
||||||
|
|
||||||
|
#define vfmlsx4(d, c, a, b) \
|
||||||
|
vfmls(d.val[0], c.val[0], a.val[0], b.val[0]); \
|
||||||
|
vfmls(d.val[1], c.val[1], a.val[1], b.val[1]); \
|
||||||
|
vfmls(d.val[2], c.val[2], a.val[2], b.val[2]); \
|
||||||
|
vfmls(d.val[3], c.val[3], a.val[3], b.val[3]);
|
||||||
|
|
||||||
|
#define vfrintx4(c, a) \
|
||||||
|
c.val[0] = vcvtnq_s64_f64(a.val[0]); \
|
||||||
|
c.val[1] = vcvtnq_s64_f64(a.val[1]); \
|
||||||
|
c.val[2] = vcvtnq_s64_f64(a.val[2]); \
|
||||||
|
c.val[3] = vcvtnq_s64_f64(a.val[3]);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Wrapper for FFT, split/merge and poly_float.c
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define FPC_MUL(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re, a_re, b_re); \
|
||||||
|
vfmls(d_re, d_re, a_im, b_im); \
|
||||||
|
vfmul(d_im, a_re, b_im); \
|
||||||
|
vfmla(d_im, d_im, a_im, b_re);
|
||||||
|
|
||||||
|
#define FPC_MULx2(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||||
|
vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||||
|
vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||||
|
vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]);
|
||||||
|
|
||||||
|
#define FPC_MULx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||||
|
vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||||
|
vfmul(d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||||
|
vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \
|
||||||
|
vfmul(d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||||
|
vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \
|
||||||
|
vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||||
|
vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \
|
||||||
|
vfmul(d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||||
|
vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \
|
||||||
|
vfmul(d_im.val[3], a_re.val[3], b_im.val[3]); \
|
||||||
|
vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]);
|
||||||
|
|
||||||
|
#define FPC_MLA(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmla(d_re, d_re, a_re, b_re); \
|
||||||
|
vfmls(d_re, d_re, a_im, b_im); \
|
||||||
|
vfmla(d_im, d_im, a_re, b_im); \
|
||||||
|
vfmla(d_im, d_im, a_im, b_re);
|
||||||
|
|
||||||
|
#define FPC_MLAx2(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]);
|
||||||
|
|
||||||
|
#define FPC_MLAx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||||
|
vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \
|
||||||
|
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||||
|
vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \
|
||||||
|
vfmla(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||||
|
vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \
|
||||||
|
vfmla(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); \
|
||||||
|
vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]);
|
||||||
|
|
||||||
|
#define FPC_MUL_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re.val[0], b_im.val[0], a_im.val[0]); \
|
||||||
|
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmul(d_re.val[1], b_im.val[1], a_im.val[1]); \
|
||||||
|
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmul(d_re.val[2], b_im.val[2], a_im.val[2]); \
|
||||||
|
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||||
|
vfmul(d_re.val[3], b_im.val[3], a_im.val[3]); \
|
||||||
|
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||||
|
vfmul(d_im.val[0], b_re.val[0], a_im.val[0]); \
|
||||||
|
vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmul(d_im.val[1], b_re.val[1], a_im.val[1]); \
|
||||||
|
vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmul(d_im.val[2], b_re.val[2], a_im.val[2]); \
|
||||||
|
vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||||
|
vfmul(d_im.val[3], b_re.val[3], a_im.val[3]); \
|
||||||
|
vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]);
|
||||||
|
|
||||||
|
#define FPC_MLA_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmla(d_re.val[0], d_re.val[0], b_im.val[0], a_im.val[0]); \
|
||||||
|
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||||
|
vfmla(d_re.val[1], d_re.val[1], b_im.val[1], a_im.val[1]); \
|
||||||
|
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||||
|
vfmla(d_re.val[2], d_re.val[2], b_im.val[2], a_im.val[2]); \
|
||||||
|
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||||
|
vfmla(d_re.val[3], d_re.val[3], b_im.val[3], a_im.val[3]); \
|
||||||
|
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||||
|
vfmla(d_im.val[0], d_im.val[0], b_re.val[0], a_im.val[0]); \
|
||||||
|
vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||||
|
vfmla(d_im.val[1], d_im.val[1], b_re.val[1], a_im.val[1]); \
|
||||||
|
vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||||
|
vfmla(d_im.val[2], d_im.val[2], b_re.val[2], a_im.val[2]); \
|
||||||
|
vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||||
|
vfmla(d_im.val[3], d_im.val[3], b_re.val[3], a_im.val[3]); \
|
||||||
|
vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]);
|
||||||
|
|
||||||
|
#define FPC_MUL_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re, a_re, b_re_im, 0); \
|
||||||
|
vfmls_lane(d_re, d_re, a_im, b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im, a_re, b_re_im, 1); \
|
||||||
|
vfmla_lane(d_im, d_im, a_im, b_re_im, 0);
|
||||||
|
|
||||||
|
#define FPC_MUL_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 1); \
|
||||||
|
vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 1); \
|
||||||
|
vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 1); \
|
||||||
|
vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 1); \
|
||||||
|
vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 0);
|
||||||
|
|
||||||
|
#define FWD_TOP(t_re, t_im, b_re, b_im, zeta_re, zeta_im) \
|
||||||
|
FPC_MUL(t_re, t_im, b_re, b_im, zeta_re, zeta_im);
|
||||||
|
|
||||||
|
#define FWD_TOP_LANE(t_re, t_im, b_re, b_im, zeta) \
|
||||||
|
FPC_MUL_LANE(t_re, t_im, b_re, b_im, zeta);
|
||||||
|
|
||||||
|
#define FWD_TOP_LANEx4(t_re, t_im, b_re, b_im, zeta) \
|
||||||
|
FPC_MUL_LANEx4(t_re, t_im, b_re, b_im, zeta);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* FPC
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define FPC_SUB(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re = vsubq_f64(a_re, b_re); \
|
||||||
|
d_im = vsubq_f64(a_im, b_im);
|
||||||
|
|
||||||
|
#define FPC_SUBx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re.val[0] = vsubq_f64(a_re.val[0], b_re.val[0]); \
|
||||||
|
d_im.val[0] = vsubq_f64(a_im.val[0], b_im.val[0]); \
|
||||||
|
d_re.val[1] = vsubq_f64(a_re.val[1], b_re.val[1]); \
|
||||||
|
d_im.val[1] = vsubq_f64(a_im.val[1], b_im.val[1]); \
|
||||||
|
d_re.val[2] = vsubq_f64(a_re.val[2], b_re.val[2]); \
|
||||||
|
d_im.val[2] = vsubq_f64(a_im.val[2], b_im.val[2]); \
|
||||||
|
d_re.val[3] = vsubq_f64(a_re.val[3], b_re.val[3]); \
|
||||||
|
d_im.val[3] = vsubq_f64(a_im.val[3], b_im.val[3]);
|
||||||
|
|
||||||
|
#define FPC_ADD(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re = vaddq_f64(a_re, b_re); \
|
||||||
|
d_im = vaddq_f64(a_im, b_im);
|
||||||
|
|
||||||
|
#define FPC_ADDx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re.val[0] = vaddq_f64(a_re.val[0], b_re.val[0]); \
|
||||||
|
d_im.val[0] = vaddq_f64(a_im.val[0], b_im.val[0]); \
|
||||||
|
d_re.val[1] = vaddq_f64(a_re.val[1], b_re.val[1]); \
|
||||||
|
d_im.val[1] = vaddq_f64(a_im.val[1], b_im.val[1]); \
|
||||||
|
d_re.val[2] = vaddq_f64(a_re.val[2], b_re.val[2]); \
|
||||||
|
d_im.val[2] = vaddq_f64(a_im.val[2], b_im.val[2]); \
|
||||||
|
d_re.val[3] = vaddq_f64(a_re.val[3], b_re.val[3]); \
|
||||||
|
d_im.val[3] = vaddq_f64(a_im.val[3], b_im.val[3]);
|
||||||
|
|
||||||
|
#define FWD_BOT(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||||
|
FPC_SUB(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||||
|
FPC_ADD(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||||
|
|
||||||
|
#define FWD_BOTx4(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||||
|
FPC_SUBx4(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||||
|
FPC_ADDx4(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* FPC_J
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define FPC_ADDJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re = vsubq_f64(a_re, b_im); \
|
||||||
|
d_im = vaddq_f64(a_im, b_re);
|
||||||
|
|
||||||
|
#define FPC_ADDJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re.val[0] = vsubq_f64(a_re.val[0], b_im.val[0]); \
|
||||||
|
d_im.val[0] = vaddq_f64(a_im.val[0], b_re.val[0]); \
|
||||||
|
d_re.val[1] = vsubq_f64(a_re.val[1], b_im.val[1]); \
|
||||||
|
d_im.val[1] = vaddq_f64(a_im.val[1], b_re.val[1]); \
|
||||||
|
d_re.val[2] = vsubq_f64(a_re.val[2], b_im.val[2]); \
|
||||||
|
d_im.val[2] = vaddq_f64(a_im.val[2], b_re.val[2]); \
|
||||||
|
d_re.val[3] = vsubq_f64(a_re.val[3], b_im.val[3]); \
|
||||||
|
d_im.val[3] = vaddq_f64(a_im.val[3], b_re.val[3]);
|
||||||
|
|
||||||
|
#define FPC_SUBJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re = vaddq_f64(a_re, b_im); \
|
||||||
|
d_im = vsubq_f64(a_im, b_re);
|
||||||
|
|
||||||
|
#define FPC_SUBJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
d_re.val[0] = vaddq_f64(a_re.val[0], b_im.val[0]); \
|
||||||
|
d_im.val[0] = vsubq_f64(a_im.val[0], b_re.val[0]); \
|
||||||
|
d_re.val[1] = vaddq_f64(a_re.val[1], b_im.val[1]); \
|
||||||
|
d_im.val[1] = vsubq_f64(a_im.val[1], b_re.val[1]); \
|
||||||
|
d_re.val[2] = vaddq_f64(a_re.val[2], b_im.val[2]); \
|
||||||
|
d_im.val[2] = vsubq_f64(a_im.val[2], b_re.val[2]); \
|
||||||
|
d_re.val[3] = vaddq_f64(a_re.val[3], b_im.val[3]); \
|
||||||
|
d_im.val[3] = vsubq_f64(a_im.val[3], b_re.val[3]);
|
||||||
|
|
||||||
|
#define FWD_BOTJ(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||||
|
FPC_SUBJ(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||||
|
FPC_ADDJ(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||||
|
|
||||||
|
#define FWD_BOTJx4(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||||
|
FPC_SUBJx4(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||||
|
FPC_ADDJx4(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||||
|
|
||||||
|
//============== Inverse FFT
|
||||||
|
/*
|
||||||
|
* FPC_J
|
||||||
|
* a * conj(b)
|
||||||
|
* Original (without swap):
|
||||||
|
* d_re = b_im * a_im + a_re * b_re;
|
||||||
|
* d_im = b_re * a_im - a_re * b_im;
|
||||||
|
*/
|
||||||
|
#define FPC_MUL_BOTJ_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re, a_re, b_re_im, 0); \
|
||||||
|
vfmla_lane(d_re, d_re, a_im, b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im, a_im, b_re_im, 0); \
|
||||||
|
vfmls_lane(d_im, d_im, a_re, b_re_im, 1);
|
||||||
|
|
||||||
|
#define FPC_MUL_BOTJ_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im.val[0], a_im.val[0], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_im.val[0], d_im.val[0], a_re.val[0], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im.val[1], a_im.val[1], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_im.val[1], d_im.val[1], a_re.val[1], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im.val[2], a_im.val[2], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_im.val[2], d_im.val[2], a_re.val[2], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_im.val[3], a_im.val[3], b_re_im, 0); \
|
||||||
|
vfmls_lane(d_im.val[3], d_im.val[3], a_re.val[3], b_re_im, 1);
|
||||||
|
|
||||||
|
#define FPC_MUL_BOTJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re, b_im, a_im); \
|
||||||
|
vfmla(d_re, d_re, a_re, b_re); \
|
||||||
|
vfmul(d_im, b_re, a_im); \
|
||||||
|
vfmls(d_im, d_im, a_re, b_im);
|
||||||
|
|
||||||
|
#define INV_TOPJ(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||||
|
FPC_SUB(t_re, t_im, a_re, a_im, b_re, b_im); \
|
||||||
|
FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||||
|
|
||||||
|
#define INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||||
|
FPC_SUBx4(t_re, t_im, a_re, a_im, b_re, b_im); \
|
||||||
|
FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||||
|
|
||||||
|
#define INV_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im) \
|
||||||
|
FPC_MUL_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im);
|
||||||
|
|
||||||
|
#define INV_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta) \
|
||||||
|
FPC_MUL_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta);
|
||||||
|
|
||||||
|
#define INV_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta) \
|
||||||
|
FPC_MUL_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* FPC_Jm
|
||||||
|
* a * -conj(b)
|
||||||
|
* d_re = a_re * b_im - a_im * b_re;
|
||||||
|
* d_im = a_im * b_im + a_re * b_re;
|
||||||
|
*/
|
||||||
|
#define FPC_MUL_BOTJm_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re, a_re, b_re_im, 1); \
|
||||||
|
vfmls_lane(d_re, d_re, a_im, b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im, a_re, b_re_im, 0); \
|
||||||
|
vfmla_lane(d_im, d_im, a_im, b_re_im, 1);
|
||||||
|
|
||||||
|
#define FPC_MUL_BOTJm_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||||
|
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 1); \
|
||||||
|
vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 1); \
|
||||||
|
vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 1); \
|
||||||
|
vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 1); \
|
||||||
|
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 1); \
|
||||||
|
vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 0); \
|
||||||
|
vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 0); \
|
||||||
|
vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 1);
|
||||||
|
|
||||||
|
#define FPC_MUL_BOTJm(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||||
|
vfmul(d_re, a_re, b_im); \
|
||||||
|
vfmls(d_re, d_re, a_im, b_re); \
|
||||||
|
vfmul(d_im, a_im, b_im); \
|
||||||
|
vfmla(d_im, d_im, a_re, b_re);
|
||||||
|
|
||||||
|
#define INV_TOPJm(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||||
|
FPC_SUB(t_re, t_im, b_re, b_im, a_re, a_im); \
|
||||||
|
FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||||
|
|
||||||
|
#define INV_TOPJmx4(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||||
|
FPC_SUBx4(t_re, t_im, b_re, b_im, a_re, a_im); \
|
||||||
|
FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||||
|
|
||||||
|
#define INV_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im) \
|
||||||
|
FPC_MUL_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im);
|
||||||
|
|
||||||
|
#define INV_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta) \
|
||||||
|
FPC_MUL_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta);
|
||||||
|
|
||||||
|
#define INV_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta) \
|
||||||
|
FPC_MUL_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta);
|
||||||
|
|
||||||
|
#endif
|
||||||
115
spqlios/lib/spqlios/q120/q120_arithmetic.h
Normal file
115
spqlios/lib/spqlios/q120/q120_arithmetic.h
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
#ifndef SPQLIOS_Q120_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_Q120_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
#include "q120_common.h"
|
||||||
|
|
||||||
|
typedef struct _q120_mat1col_product_baa_precomp q120_mat1col_product_baa_precomp;
|
||||||
|
typedef struct _q120_mat1col_product_bbb_precomp q120_mat1col_product_bbb_precomp;
|
||||||
|
typedef struct _q120_mat1col_product_bbc_precomp q120_mat1col_product_bbc_precomp;
|
||||||
|
|
||||||
|
EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp();
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp*);
|
||||||
|
EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp();
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp*);
|
||||||
|
EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp();
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp*);
|
||||||
|
|
||||||
|
// ell < 10000
|
||||||
|
EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120a* const x, const q120a* const y);
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120b* const x, const q120b* const y);
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120b* const x, const q120c* const y);
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120a* const x, const q120a* const y);
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120b* const x, const q120b* const y);
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res,
|
||||||
|
const q120b* const x, const q120c* const y);
|
||||||
|
|
||||||
|
EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y);
|
||||||
|
EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y);
|
||||||
|
EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y);
|
||||||
|
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief extract 1 q120x2 block from one q120 ntt vectors
|
||||||
|
* @param nn the size of each vector
|
||||||
|
* @param blk the block id to extract (<nn/2)
|
||||||
|
* @param dst the output: nrows q120x2's dst[i] = src[i](blk)
|
||||||
|
* @param src the input: nrows q120 ntt vecs's
|
||||||
|
*/
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2b* const dst, // 8 doubles
|
||||||
|
const q120b* const src // a reim vector
|
||||||
|
);
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2c* const dst, // 8 doubles
|
||||||
|
const q120c* const src // a reim vector
|
||||||
|
);
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120b_avx(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2b* const dst, // 8 doubles
|
||||||
|
const q120b* const src // a reim vector
|
||||||
|
);
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120c_avx(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2c* const dst, // 8 doubles
|
||||||
|
const q120c* const src // a reim vector
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief extract 1 reim4 block from nrows reim vectors of m complexes
|
||||||
|
* @param nn the size of each q120
|
||||||
|
* @param nrows the number of q120 (ntt) vectors
|
||||||
|
* @param blk the block id to extract (<m/4)
|
||||||
|
* @param dst the output: nrows q120x2's dst[i] = src[i](blk)
|
||||||
|
* @param src the input: nrows q120 ntt vectors
|
||||||
|
*/
|
||||||
|
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref(
|
||||||
|
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||||
|
q120x2b* const dst, // nrows * 2 q120
|
||||||
|
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||||
|
);
|
||||||
|
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_avx(
|
||||||
|
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||||
|
q120x2b* const dst, // nrows * 2 q120
|
||||||
|
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief saves 1 single q120x2 block in a q120 vectors of size nn
|
||||||
|
* @param nn the size of the output q120
|
||||||
|
* @param blk the block id to save (<nn/2)
|
||||||
|
* @param dest the output q120b vector: dst(blk) = src
|
||||||
|
* @param src the input q120x2b
|
||||||
|
*/
|
||||||
|
EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120b* dest, // 1 reim vector of length m
|
||||||
|
const q120x2b* src // 8 doubles
|
||||||
|
);
|
||||||
|
EXPORT void q120x2b_save_1blk_to_q120b_avx(uint64_t nn, uint64_t blk,
|
||||||
|
q120b* dest, // 1 reim vector of length m
|
||||||
|
const q120x2b* src // 8 doubles
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y);
|
||||||
|
|
||||||
|
EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y);
|
||||||
|
|
||||||
|
EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x);
|
||||||
|
|
||||||
|
EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x);
|
||||||
|
|
||||||
|
EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x);
|
||||||
|
|
||||||
|
EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_Q120_ARITHMETIC_H
|
||||||
567
spqlios/lib/spqlios/q120/q120_arithmetic_avx2.c
Normal file
567
spqlios/lib/spqlios/q120/q120_arithmetic_avx2.c
Normal file
@@ -0,0 +1,567 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include "q120_arithmetic.h"
|
||||||
|
#include "q120_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120a* const x, const q120a* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* - res = acc1 + acc2 . ((2^H) % Q)
|
||||||
|
* - acc1 is the sum of H LSB of products x[i].y[i]
|
||||||
|
* - acc2 is the sum of 64-H MSB of products x[i]].y[i]
|
||||||
|
* - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits
|
||||||
|
* - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits
|
||||||
|
*/
|
||||||
|
|
||||||
|
const uint64_t H = precomp->h;
|
||||||
|
const __m256i MASK = _mm256_set1_epi64x((UINT64_C(1) << H) - 1);
|
||||||
|
|
||||||
|
__m256i acc1 = _mm256_setzero_si256();
|
||||||
|
__m256i acc2 = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
__m256i a = _mm256_loadu_si256(x_ptr);
|
||||||
|
__m256i b = _mm256_loadu_si256(y_ptr);
|
||||||
|
__m256i t = _mm256_mul_epu32(a, b);
|
||||||
|
|
||||||
|
acc1 = _mm256_add_epi64(acc1, _mm256_and_si256(t, MASK));
|
||||||
|
acc2 = _mm256_add_epi64(acc2, _mm256_srli_epi64(t, H));
|
||||||
|
|
||||||
|
x_ptr++;
|
||||||
|
y_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const __m256i H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->h_pow_red);
|
||||||
|
|
||||||
|
__m256i t = _mm256_add_epi64(acc1, _mm256_mul_epu32(acc2, H_POW_RED));
|
||||||
|
_mm256_storeu_si256((__m256i*)res, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120b* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products:
|
||||||
|
* - x_i = xl_i + xh_i . 2^32
|
||||||
|
* - y_i = yl_i + yh_i . 2^32
|
||||||
|
* - A_i = xl_i . yl_i
|
||||||
|
* - B_i = xl_i . yh_i
|
||||||
|
* - C_i = xh_i . yl_i
|
||||||
|
* - D_i = xh_i . yh_i
|
||||||
|
* - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64
|
||||||
|
* 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts
|
||||||
|
* - A_i = Al_i + Ah_i . 2^32
|
||||||
|
* - B_i = Bl_i + Bh_i . 2^32
|
||||||
|
* - C_i = Cl_i + Ch_i . 2^32
|
||||||
|
* - D_i = Dl_i + Dh_i . 2^32
|
||||||
|
* 3. Compute the sums:
|
||||||
|
* - S1 = \sum Al_i
|
||||||
|
* - S2 = \sum (Ah_i + Bl_i + Cl_i)
|
||||||
|
* - S3 = \sum (Bh_i + Ch_i + Dl_i)
|
||||||
|
* - S4 = \sum Dh_i
|
||||||
|
* - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) +
|
||||||
|
* log2(3) bits
|
||||||
|
* - for ell == 10000 S2, S3 have < 47 bits
|
||||||
|
* 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2))
|
||||||
|
* - S1 = S1l + S1h . 2^24
|
||||||
|
* - S2 = S2l + S2h . 2^24
|
||||||
|
* - S3 = S3l + S3h . 2^24
|
||||||
|
* - S4 = S4l + S4h . 2^24
|
||||||
|
* 5. Compute final result as:
|
||||||
|
* - \sum x_i . y_i = S1l + S1h . 2^24
|
||||||
|
* + S2l . 2^32 + S2h . 2^(32+24)
|
||||||
|
* + S3l . 2^64 + S3h . 2^(64 + 24)
|
||||||
|
* + S4l . 2^96 + S4l . 2^(96+24)
|
||||||
|
* - here the powers of 2 are reduced modulo the primes Q before
|
||||||
|
* multiplications
|
||||||
|
* - the result will be on 24 + 3 + bit size of primes Q
|
||||||
|
*/
|
||||||
|
const uint64_t H1 = 32;
|
||||||
|
const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1);
|
||||||
|
|
||||||
|
__m256i s1 = _mm256_setzero_si256();
|
||||||
|
__m256i s2 = _mm256_setzero_si256();
|
||||||
|
__m256i s3 = _mm256_setzero_si256();
|
||||||
|
__m256i s4 = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
__m256i x = _mm256_loadu_si256(x_ptr);
|
||||||
|
__m256i xl = _mm256_and_si256(x, MASK1);
|
||||||
|
__m256i xh = _mm256_srli_epi64(x, H1);
|
||||||
|
|
||||||
|
__m256i y = _mm256_loadu_si256(y_ptr);
|
||||||
|
__m256i yl = _mm256_and_si256(y, MASK1);
|
||||||
|
__m256i yh = _mm256_srli_epi64(y, H1);
|
||||||
|
|
||||||
|
__m256i a = _mm256_mul_epu32(xl, yl);
|
||||||
|
__m256i b = _mm256_mul_epu32(xl, yh);
|
||||||
|
__m256i c = _mm256_mul_epu32(xh, yl);
|
||||||
|
__m256i d = _mm256_mul_epu32(xh, yh);
|
||||||
|
|
||||||
|
s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1));
|
||||||
|
|
||||||
|
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1));
|
||||||
|
s2 = _mm256_add_epi64(s2, _mm256_and_si256(b, MASK1));
|
||||||
|
s2 = _mm256_add_epi64(s2, _mm256_and_si256(c, MASK1));
|
||||||
|
|
||||||
|
s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(b, H1));
|
||||||
|
s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(c, H1));
|
||||||
|
s3 = _mm256_add_epi64(s3, _mm256_and_si256(d, MASK1));
|
||||||
|
|
||||||
|
s4 = _mm256_add_epi64(s4, _mm256_srli_epi64(d, H1));
|
||||||
|
|
||||||
|
x_ptr++;
|
||||||
|
y_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1);
|
||||||
|
|
||||||
|
const __m256i S1H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s1h_pow_red);
|
||||||
|
__m256i s1l = _mm256_and_si256(s1, MASK2);
|
||||||
|
__m256i s1h = _mm256_srli_epi64(s1, H2);
|
||||||
|
__m256i t = _mm256_add_epi64(s1l, _mm256_mul_epu32(s1h, S1H_POW_RED));
|
||||||
|
|
||||||
|
const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red);
|
||||||
|
const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red);
|
||||||
|
__m256i s2l = _mm256_and_si256(s2, MASK2);
|
||||||
|
__m256i s2h = _mm256_srli_epi64(s2, H2);
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED));
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED));
|
||||||
|
|
||||||
|
const __m256i S3L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3l_pow_red);
|
||||||
|
const __m256i S3H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3h_pow_red);
|
||||||
|
__m256i s3l = _mm256_and_si256(s3, MASK2);
|
||||||
|
__m256i s3h = _mm256_srli_epi64(s3, H2);
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s3l, S3L_POW_RED));
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s3h, S3H_POW_RED));
|
||||||
|
|
||||||
|
const __m256i S4L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4l_pow_red);
|
||||||
|
const __m256i S4H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4h_pow_red);
|
||||||
|
__m256i s4l = _mm256_and_si256(s4, MASK2);
|
||||||
|
__m256i s4h = _mm256_srli_epi64(s4, H2);
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s4l, S4L_POW_RED));
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s4h, S4H_POW_RED));
|
||||||
|
|
||||||
|
_mm256_storeu_si256((__m256i*)res, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* 0. We have
|
||||||
|
* - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q
|
||||||
|
* 1. Split x_i in 2 32-bit parts and compute the cross-products:
|
||||||
|
* - x_i = xl_i + xh_i . 2^32
|
||||||
|
* - A_i = xl_i . y1_i
|
||||||
|
* - B_i = xh_i . y2_i
|
||||||
|
* - we have x_i . y_i == A_i + B_i
|
||||||
|
* 2. Split A_i and B_i into 2 32-bit parts
|
||||||
|
* - A_i = Al_i + Ah_i . 2^32
|
||||||
|
* - B_i = Bl_i + Bh_i . 2^32
|
||||||
|
* 3. Compute the sums:
|
||||||
|
* - S1 = \sum Al_i + Bl_i
|
||||||
|
* - S2 = \sum Ah_i + Bh_i
|
||||||
|
* - here S1 and S2 have 32 + log2(ell) bits
|
||||||
|
* - for ell == 10000 S1, S2 have < 46 bits
|
||||||
|
* 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46)
|
||||||
|
* - S2 = S2l + S2h . 2^27
|
||||||
|
* 5. Compute final result as:
|
||||||
|
* - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27)
|
||||||
|
* - here the powers of 2 are reduced modulo the primes Q before
|
||||||
|
* multiplications
|
||||||
|
* - the result will be on < 52 bits
|
||||||
|
*/
|
||||||
|
|
||||||
|
const uint64_t H1 = 32;
|
||||||
|
const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1);
|
||||||
|
|
||||||
|
__m256i s1 = _mm256_setzero_si256();
|
||||||
|
__m256i s2 = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
__m256i x = _mm256_loadu_si256(x_ptr);
|
||||||
|
__m256i xl = _mm256_and_si256(x, MASK1);
|
||||||
|
__m256i xh = _mm256_srli_epi64(x, H1);
|
||||||
|
|
||||||
|
__m256i y = _mm256_loadu_si256(y_ptr);
|
||||||
|
__m256i y0 = _mm256_and_si256(y, MASK1);
|
||||||
|
__m256i y1 = _mm256_srli_epi64(y, H1);
|
||||||
|
|
||||||
|
__m256i a = _mm256_mul_epu32(xl, y0);
|
||||||
|
__m256i b = _mm256_mul_epu32(xh, y1);
|
||||||
|
|
||||||
|
s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1));
|
||||||
|
s1 = _mm256_add_epi64(s1, _mm256_and_si256(b, MASK1));
|
||||||
|
|
||||||
|
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1));
|
||||||
|
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(b, H1));
|
||||||
|
|
||||||
|
x_ptr++;
|
||||||
|
y_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1);
|
||||||
|
|
||||||
|
__m256i t = s1;
|
||||||
|
|
||||||
|
const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red);
|
||||||
|
const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red);
|
||||||
|
__m256i s2l = _mm256_and_si256(s2, MASK2);
|
||||||
|
__m256i s2h = _mm256_srli_epi64(s2, H2);
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED));
|
||||||
|
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED));
|
||||||
|
|
||||||
|
_mm256_storeu_si256((__m256i*)res, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated keeping this one for history only.
|
||||||
|
* There is a slight register starvation condition on the q120x2_vec_mat2cols
|
||||||
|
* strategy below sounds better.
|
||||||
|
*/
|
||||||
|
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||||
|
__m256i s1 = _mm256_setzero_si256();
|
||||||
|
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||||
|
__m256i s3 = _mm256_setzero_si256();
|
||||||
|
__m256i s4 = _mm256_setzero_si256(); // col 2a
|
||||||
|
__m256i s5 = _mm256_setzero_si256();
|
||||||
|
__m256i s6 = _mm256_setzero_si256(); // col 2b
|
||||||
|
__m256i s7 = _mm256_setzero_si256();
|
||||||
|
__m256i s8, s9, s10, s11;
|
||||||
|
__m256i s12, s13, s14, s15;
|
||||||
|
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
__m256i* res_ptr = (__m256i*)res;
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
s8 = _mm256_loadu_si256(x_ptr);
|
||||||
|
s9 = _mm256_loadu_si256(x_ptr + 1);
|
||||||
|
s10 = _mm256_srli_epi64(s8, 32);
|
||||||
|
s11 = _mm256_srli_epi64(s9, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_loadu_si256(y_ptr);
|
||||||
|
s13 = _mm256_loadu_si256(y_ptr + 1);
|
||||||
|
s14 = _mm256_srli_epi64(s12, 32);
|
||||||
|
s15 = _mm256_srli_epi64(s13, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1
|
||||||
|
s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3
|
||||||
|
s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1
|
||||||
|
s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3
|
||||||
|
|
||||||
|
s10 = _mm256_slli_epi64(s12, 32); // -> s0
|
||||||
|
s11 = _mm256_slli_epi64(s13, 32); // -> s2
|
||||||
|
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||||
|
s13 = _mm256_srli_epi64(s13, 32); // -> s3
|
||||||
|
s10 = _mm256_srli_epi64(s10, 32); // -> s0
|
||||||
|
s11 = _mm256_srli_epi64(s11, 32); // -> s2
|
||||||
|
|
||||||
|
s0 = _mm256_add_epi64(s0, s10);
|
||||||
|
s1 = _mm256_add_epi64(s1, s12);
|
||||||
|
s2 = _mm256_add_epi64(s2, s11);
|
||||||
|
s3 = _mm256_add_epi64(s3, s13);
|
||||||
|
|
||||||
|
s10 = _mm256_slli_epi64(s14, 32); // -> s0
|
||||||
|
s11 = _mm256_slli_epi64(s15, 32); // -> s2
|
||||||
|
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||||
|
s15 = _mm256_srli_epi64(s15, 32); // -> s3
|
||||||
|
s10 = _mm256_srli_epi64(s10, 32); // -> s0
|
||||||
|
s11 = _mm256_srli_epi64(s11, 32); // -> s2
|
||||||
|
|
||||||
|
s0 = _mm256_add_epi64(s0, s10);
|
||||||
|
s1 = _mm256_add_epi64(s1, s14);
|
||||||
|
s2 = _mm256_add_epi64(s2, s11);
|
||||||
|
s3 = _mm256_add_epi64(s3, s15);
|
||||||
|
|
||||||
|
// deal with the second column
|
||||||
|
// s8,s9 are still in place!
|
||||||
|
s10 = _mm256_srli_epi64(s8, 32);
|
||||||
|
s11 = _mm256_srli_epi64(s9, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_loadu_si256(y_ptr + 2);
|
||||||
|
s13 = _mm256_loadu_si256(y_ptr + 3);
|
||||||
|
s14 = _mm256_srli_epi64(s12, 32);
|
||||||
|
s15 = _mm256_srli_epi64(s13, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_mul_epu32(s8, s12); // -> s4,s5
|
||||||
|
s13 = _mm256_mul_epu32(s9, s13); // -> s6,s7
|
||||||
|
s14 = _mm256_mul_epu32(s10, s14); // -> s4,s5
|
||||||
|
s15 = _mm256_mul_epu32(s11, s15); // -> s6,s7
|
||||||
|
|
||||||
|
s10 = _mm256_slli_epi64(s12, 32); // -> s4
|
||||||
|
s11 = _mm256_slli_epi64(s13, 32); // -> s6
|
||||||
|
s12 = _mm256_srli_epi64(s12, 32); // -> s5
|
||||||
|
s13 = _mm256_srli_epi64(s13, 32); // -> s7
|
||||||
|
s10 = _mm256_srli_epi64(s10, 32); // -> s4
|
||||||
|
s11 = _mm256_srli_epi64(s11, 32); // -> s6
|
||||||
|
|
||||||
|
s4 = _mm256_add_epi64(s4, s10);
|
||||||
|
s5 = _mm256_add_epi64(s5, s12);
|
||||||
|
s6 = _mm256_add_epi64(s6, s11);
|
||||||
|
s7 = _mm256_add_epi64(s7, s13);
|
||||||
|
|
||||||
|
s10 = _mm256_slli_epi64(s14, 32); // -> s4
|
||||||
|
s11 = _mm256_slli_epi64(s15, 32); // -> s6
|
||||||
|
s14 = _mm256_srli_epi64(s14, 32); // -> s5
|
||||||
|
s15 = _mm256_srli_epi64(s15, 32); // -> s7
|
||||||
|
s10 = _mm256_srli_epi64(s10, 32); // -> s4
|
||||||
|
s11 = _mm256_srli_epi64(s11, 32); // -> s6
|
||||||
|
|
||||||
|
s4 = _mm256_add_epi64(s4, s10);
|
||||||
|
s5 = _mm256_add_epi64(s5, s14);
|
||||||
|
s6 = _mm256_add_epi64(s6, s11);
|
||||||
|
s7 = _mm256_add_epi64(s7, s15);
|
||||||
|
|
||||||
|
x_ptr += 2;
|
||||||
|
y_ptr += 4;
|
||||||
|
}
|
||||||
|
// final reduction
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||||
|
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||||
|
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||||
|
//--- s0,s1
|
||||||
|
s11 = _mm256_and_si256(s1, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s1, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s0 = _mm256_add_epi64(s0, s13);
|
||||||
|
s0 = _mm256_add_epi64(s0, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||||
|
//--- s2,s3
|
||||||
|
s11 = _mm256_and_si256(s3, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s3, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s2 = _mm256_add_epi64(s2, s13);
|
||||||
|
s2 = _mm256_add_epi64(s2, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||||
|
//--- s4,s5
|
||||||
|
s11 = _mm256_and_si256(s5, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s5, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s4 = _mm256_add_epi64(s4, s13);
|
||||||
|
s4 = _mm256_add_epi64(s4, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 2, s4);
|
||||||
|
//--- s6,s7
|
||||||
|
s11 = _mm256_and_si256(s7, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s7, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s6 = _mm256_add_epi64(s6, s13);
|
||||||
|
s6 = _mm256_add_epi64(s6, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 3, s6);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||||
|
__m256i s1 = _mm256_setzero_si256();
|
||||||
|
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||||
|
__m256i s3 = _mm256_setzero_si256();
|
||||||
|
__m256i s4 = _mm256_setzero_si256(); // col 2a
|
||||||
|
__m256i s5 = _mm256_setzero_si256();
|
||||||
|
__m256i s6 = _mm256_setzero_si256(); // col 2b
|
||||||
|
__m256i s7 = _mm256_setzero_si256();
|
||||||
|
__m256i s8, s9, s10, s11;
|
||||||
|
__m256i s12, s13, s14, s15;
|
||||||
|
|
||||||
|
s11 = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
__m256i* res_ptr = (__m256i*)res;
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
// treat item a
|
||||||
|
s8 = _mm256_loadu_si256(x_ptr);
|
||||||
|
s9 = _mm256_srli_epi64(s8, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_loadu_si256(y_ptr);
|
||||||
|
s13 = _mm256_loadu_si256(y_ptr + 2);
|
||||||
|
s14 = _mm256_srli_epi64(s12, 32);
|
||||||
|
s15 = _mm256_srli_epi64(s13, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_mul_epu32(s8, s12); // c1a -> s0,s1
|
||||||
|
s13 = _mm256_mul_epu32(s8, s13); // c2a -> s4,s5
|
||||||
|
s14 = _mm256_mul_epu32(s9, s14); // c1a -> s0,s1
|
||||||
|
s15 = _mm256_mul_epu32(s9, s15); // c2a -> s4,s5
|
||||||
|
|
||||||
|
s8 = _mm256_and_si256(s12, s11); // -> s0
|
||||||
|
s9 = _mm256_and_si256(s13, s11); // -> s4
|
||||||
|
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||||
|
s13 = _mm256_srli_epi64(s13, 32); // -> s5
|
||||||
|
s0 = _mm256_add_epi64(s0, s8);
|
||||||
|
s1 = _mm256_add_epi64(s1, s12);
|
||||||
|
s4 = _mm256_add_epi64(s4, s9);
|
||||||
|
s5 = _mm256_add_epi64(s5, s13);
|
||||||
|
|
||||||
|
s8 = _mm256_and_si256(s14, s11); // -> s0
|
||||||
|
s9 = _mm256_and_si256(s15, s11); // -> s4
|
||||||
|
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||||
|
s15 = _mm256_srli_epi64(s15, 32); // -> s5
|
||||||
|
s0 = _mm256_add_epi64(s0, s8);
|
||||||
|
s1 = _mm256_add_epi64(s1, s14);
|
||||||
|
s4 = _mm256_add_epi64(s4, s9);
|
||||||
|
s5 = _mm256_add_epi64(s5, s15);
|
||||||
|
|
||||||
|
// treat item b
|
||||||
|
s8 = _mm256_loadu_si256(x_ptr + 1);
|
||||||
|
s9 = _mm256_srli_epi64(s8, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_loadu_si256(y_ptr + 1);
|
||||||
|
s13 = _mm256_loadu_si256(y_ptr + 3);
|
||||||
|
s14 = _mm256_srli_epi64(s12, 32);
|
||||||
|
s15 = _mm256_srli_epi64(s13, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_mul_epu32(s8, s12); // c1b -> s2,s3
|
||||||
|
s13 = _mm256_mul_epu32(s8, s13); // c2b -> s6,s7
|
||||||
|
s14 = _mm256_mul_epu32(s9, s14); // c1b -> s2,s3
|
||||||
|
s15 = _mm256_mul_epu32(s9, s15); // c2b -> s6,s7
|
||||||
|
|
||||||
|
s8 = _mm256_and_si256(s12, s11); // -> s2
|
||||||
|
s9 = _mm256_and_si256(s13, s11); // -> s6
|
||||||
|
s12 = _mm256_srli_epi64(s12, 32); // -> s3
|
||||||
|
s13 = _mm256_srli_epi64(s13, 32); // -> s7
|
||||||
|
s2 = _mm256_add_epi64(s2, s8);
|
||||||
|
s3 = _mm256_add_epi64(s3, s12);
|
||||||
|
s6 = _mm256_add_epi64(s6, s9);
|
||||||
|
s7 = _mm256_add_epi64(s7, s13);
|
||||||
|
|
||||||
|
s8 = _mm256_and_si256(s14, s11); // -> s2
|
||||||
|
s9 = _mm256_and_si256(s15, s11); // -> s6
|
||||||
|
s14 = _mm256_srli_epi64(s14, 32); // -> s3
|
||||||
|
s15 = _mm256_srli_epi64(s15, 32); // -> s7
|
||||||
|
s2 = _mm256_add_epi64(s2, s8);
|
||||||
|
s3 = _mm256_add_epi64(s3, s14);
|
||||||
|
s6 = _mm256_add_epi64(s6, s9);
|
||||||
|
s7 = _mm256_add_epi64(s7, s15);
|
||||||
|
|
||||||
|
x_ptr += 2;
|
||||||
|
y_ptr += 4;
|
||||||
|
}
|
||||||
|
// final reduction
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||||
|
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||||
|
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||||
|
//--- s0,s1
|
||||||
|
s11 = _mm256_and_si256(s1, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s1, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s0 = _mm256_add_epi64(s0, s13);
|
||||||
|
s0 = _mm256_add_epi64(s0, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||||
|
//--- s2,s3
|
||||||
|
s11 = _mm256_and_si256(s3, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s3, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s2 = _mm256_add_epi64(s2, s13);
|
||||||
|
s2 = _mm256_add_epi64(s2, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||||
|
//--- s4,s5
|
||||||
|
s11 = _mm256_and_si256(s5, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s5, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s4 = _mm256_add_epi64(s4, s13);
|
||||||
|
s4 = _mm256_add_epi64(s4, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 2, s4);
|
||||||
|
//--- s6,s7
|
||||||
|
s11 = _mm256_and_si256(s7, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s7, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s6 = _mm256_add_epi64(s6, s13);
|
||||||
|
s6 = _mm256_add_epi64(s6, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 3, s6);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||||
|
__m256i s1 = _mm256_setzero_si256();
|
||||||
|
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||||
|
__m256i s3 = _mm256_setzero_si256();
|
||||||
|
__m256i s4 = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||||
|
__m256i s8, s9, s10, s11;
|
||||||
|
__m256i s12, s13, s14, s15;
|
||||||
|
|
||||||
|
const __m256i* x_ptr = (__m256i*)x;
|
||||||
|
const __m256i* y_ptr = (__m256i*)y;
|
||||||
|
__m256i* res_ptr = (__m256i*)res;
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
s8 = _mm256_loadu_si256(x_ptr);
|
||||||
|
s9 = _mm256_loadu_si256(x_ptr + 1);
|
||||||
|
s10 = _mm256_srli_epi64(s8, 32);
|
||||||
|
s11 = _mm256_srli_epi64(s9, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_loadu_si256(y_ptr);
|
||||||
|
s13 = _mm256_loadu_si256(y_ptr + 1);
|
||||||
|
s14 = _mm256_srli_epi64(s12, 32);
|
||||||
|
s15 = _mm256_srli_epi64(s13, 32);
|
||||||
|
|
||||||
|
s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1
|
||||||
|
s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3
|
||||||
|
s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1
|
||||||
|
s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3
|
||||||
|
|
||||||
|
s8 = _mm256_and_si256(s12, s4); // -> s0
|
||||||
|
s9 = _mm256_and_si256(s13, s4); // -> s2
|
||||||
|
s10 = _mm256_and_si256(s14, s4); // -> s0
|
||||||
|
s11 = _mm256_and_si256(s15, s4); // -> s2
|
||||||
|
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||||
|
s13 = _mm256_srli_epi64(s13, 32); // -> s3
|
||||||
|
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||||
|
s15 = _mm256_srli_epi64(s15, 32); // -> s3
|
||||||
|
|
||||||
|
s0 = _mm256_add_epi64(s0, s8);
|
||||||
|
s1 = _mm256_add_epi64(s1, s12);
|
||||||
|
s2 = _mm256_add_epi64(s2, s9);
|
||||||
|
s3 = _mm256_add_epi64(s3, s13);
|
||||||
|
s0 = _mm256_add_epi64(s0, s10);
|
||||||
|
s1 = _mm256_add_epi64(s1, s14);
|
||||||
|
s2 = _mm256_add_epi64(s2, s11);
|
||||||
|
s3 = _mm256_add_epi64(s3, s15);
|
||||||
|
|
||||||
|
x_ptr += 2;
|
||||||
|
y_ptr += 2;
|
||||||
|
}
|
||||||
|
// final reduction
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||||
|
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||||
|
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||||
|
//--- s0,s1
|
||||||
|
s11 = _mm256_and_si256(s1, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s1, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s0 = _mm256_add_epi64(s0, s13);
|
||||||
|
s0 = _mm256_add_epi64(s0, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||||
|
//--- s2,s3
|
||||||
|
s11 = _mm256_and_si256(s3, s8);
|
||||||
|
s12 = _mm256_srli_epi64(s3, H2);
|
||||||
|
s13 = _mm256_mul_epu32(s11, s9);
|
||||||
|
s14 = _mm256_mul_epu32(s12, s10);
|
||||||
|
s2 = _mm256_add_epi64(s2, s13);
|
||||||
|
s2 = _mm256_add_epi64(s2, s14);
|
||||||
|
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||||
|
}
|
||||||
37
spqlios/lib/spqlios/q120/q120_arithmetic_private.h
Normal file
37
spqlios/lib/spqlios/q120/q120_arithmetic_private.h
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||||
|
#define SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
typedef struct _q120_mat1col_product_baa_precomp {
|
||||||
|
uint64_t h;
|
||||||
|
uint64_t h_pow_red[4];
|
||||||
|
#ifndef NDEBUG
|
||||||
|
double res_bit_size;
|
||||||
|
#endif
|
||||||
|
} q120_mat1col_product_baa_precomp;
|
||||||
|
|
||||||
|
typedef struct _q120_mat1col_product_bbb_precomp {
|
||||||
|
uint64_t h;
|
||||||
|
uint64_t s1h_pow_red[4];
|
||||||
|
uint64_t s2l_pow_red[4];
|
||||||
|
uint64_t s2h_pow_red[4];
|
||||||
|
uint64_t s3l_pow_red[4];
|
||||||
|
uint64_t s3h_pow_red[4];
|
||||||
|
uint64_t s4l_pow_red[4];
|
||||||
|
uint64_t s4h_pow_red[4];
|
||||||
|
#ifndef NDEBUG
|
||||||
|
double res_bit_size;
|
||||||
|
#endif
|
||||||
|
} q120_mat1col_product_bbb_precomp;
|
||||||
|
|
||||||
|
typedef struct _q120_mat1col_product_bbc_precomp {
|
||||||
|
uint64_t h;
|
||||||
|
uint64_t s2l_pow_red[4];
|
||||||
|
uint64_t s2h_pow_red[4];
|
||||||
|
#ifndef NDEBUG
|
||||||
|
double res_bit_size;
|
||||||
|
#endif
|
||||||
|
} q120_mat1col_product_bbc_precomp;
|
||||||
|
|
||||||
|
#endif // SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||||
506
spqlios/lib/spqlios/q120/q120_arithmetic_ref.c
Normal file
506
spqlios/lib/spqlios/q120/q120_arithmetic_ref.c
Normal file
@@ -0,0 +1,506 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "q120_arithmetic.h"
|
||||||
|
#include "q120_arithmetic_private.h"
|
||||||
|
#include "q120_common.h"
|
||||||
|
|
||||||
|
#define MODQ(val, q) ((val) % (q))
|
||||||
|
|
||||||
|
double comp_bit_size_red(const uint64_t h, const uint64_t qs[4]) {
|
||||||
|
assert(h < 128);
|
||||||
|
double h_pow2_bs = 0;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
double t = log2((double)MODQ((__uint128_t)1 << h, qs[k]));
|
||||||
|
if (t > h_pow2_bs) h_pow2_bs = t;
|
||||||
|
}
|
||||||
|
return h_pow2_bs;
|
||||||
|
}
|
||||||
|
|
||||||
|
double comp_bit_size_sum(const uint64_t n, const double* const bs) {
|
||||||
|
double s = 0;
|
||||||
|
for (uint64_t i = 0; i < n; ++i) {
|
||||||
|
s += pow(2, bs[i]);
|
||||||
|
}
|
||||||
|
return log2(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* precomp) {
|
||||||
|
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||||
|
|
||||||
|
double min_res_bs = 1000;
|
||||||
|
uint64_t min_h = -1;
|
||||||
|
|
||||||
|
double ell_bs = log2((double)MAX_ELL);
|
||||||
|
for (uint64_t h = 1; h < 64; ++h) {
|
||||||
|
double h_pow2_bs = comp_bit_size_red(h, qs);
|
||||||
|
|
||||||
|
const double bs[] = {h + ell_bs, 64 - h + ell_bs + h_pow2_bs};
|
||||||
|
const double res_bs = comp_bit_size_sum(2, bs);
|
||||||
|
|
||||||
|
if (min_res_bs > res_bs) {
|
||||||
|
min_res_bs = res_bs;
|
||||||
|
min_h = h;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(min_res_bs < 64);
|
||||||
|
precomp->h = min_h;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
precomp->h_pow_red[k] = MODQ(UINT64_C(1) << precomp->h, qs[k]);
|
||||||
|
}
|
||||||
|
#ifndef NDEBUG
|
||||||
|
precomp->res_bit_size = min_res_bs;
|
||||||
|
#endif
|
||||||
|
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp() {
|
||||||
|
q120_mat1col_product_baa_precomp* res = malloc(sizeof(q120_mat1col_product_baa_precomp));
|
||||||
|
vec_mat1col_product_baa_precomp(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* addr) { free(addr); }
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120a* const x, const q120a* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* - res = acc1 + acc2 . ((2^H) % Q)
|
||||||
|
* - acc1 is the sum of H LSB of products x[i].y[i]
|
||||||
|
* - acc2 is the sum of 64-H MSB of products x[i]].y[i]
|
||||||
|
* - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits
|
||||||
|
* - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits
|
||||||
|
*/
|
||||||
|
const uint64_t H = precomp->h;
|
||||||
|
const uint64_t MASK = (UINT64_C(1) << H) - 1;
|
||||||
|
|
||||||
|
uint64_t acc1[4] = {0, 0, 0, 0}; // accumulate H least significant bits of product
|
||||||
|
uint64_t acc2[4] = {0, 0, 0, 0}; // accumulate 64 - H most significan bits of product
|
||||||
|
|
||||||
|
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||||
|
const uint64_t* const y_ptr = (uint64_t*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||||
|
for (uint64_t j = 0; j < 4; ++j) {
|
||||||
|
uint64_t t = x_ptr[i + j] * y_ptr[i + j];
|
||||||
|
acc1[j] += t & MASK;
|
||||||
|
acc2[j] += t >> H;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t* const res_ptr = (uint64_t*)res;
|
||||||
|
for (uint64_t j = 0; j < 4; ++j) {
|
||||||
|
res_ptr[j] = acc1[j] + acc2[j] * precomp->h_pow_red[j];
|
||||||
|
assert(log2(res_ptr[j]) < precomp->res_bit_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* precomp) {
|
||||||
|
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||||
|
|
||||||
|
double ell_bs = log2((double)MAX_ELL);
|
||||||
|
double min_res_bs = 1000;
|
||||||
|
uint64_t min_h = -1;
|
||||||
|
|
||||||
|
const double s1_bs = 32 + ell_bs;
|
||||||
|
const double s2_bs = 32 + ell_bs + log2(3);
|
||||||
|
const double s3_bs = 32 + ell_bs + log2(3);
|
||||||
|
const double s4_bs = 32 + ell_bs;
|
||||||
|
for (uint64_t h = 16; h < 32; ++h) {
|
||||||
|
const double s1l_bs = h;
|
||||||
|
const double s1h_bs = (s1_bs - h) + comp_bit_size_red(h, qs);
|
||||||
|
const double s2l_bs = h + comp_bit_size_red(32, qs);
|
||||||
|
const double s2h_bs = (s2_bs - h) + comp_bit_size_red(32 + h, qs);
|
||||||
|
const double s3l_bs = h + comp_bit_size_red(64, qs);
|
||||||
|
const double s3h_bs = (s3_bs - h) + comp_bit_size_red(64 + h, qs);
|
||||||
|
const double s4l_bs = h + comp_bit_size_red(96, qs);
|
||||||
|
const double s4h_bs = (s4_bs - h) + comp_bit_size_red(96 + h, qs);
|
||||||
|
|
||||||
|
const double bs[] = {s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs};
|
||||||
|
const double res_bs = comp_bit_size_sum(8, bs);
|
||||||
|
|
||||||
|
if (min_res_bs > res_bs) {
|
||||||
|
min_res_bs = res_bs;
|
||||||
|
min_h = h;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(min_res_bs < 64);
|
||||||
|
precomp->h = min_h;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
precomp->s1h_pow_red[k] = UINT64_C(1) << precomp->h; // 2^24
|
||||||
|
precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); // 2^32
|
||||||
|
precomp->s2h_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(32+24)
|
||||||
|
precomp->s3l_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^64 = 2^(32+32)
|
||||||
|
precomp->s3h_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(64+24)
|
||||||
|
precomp->s4l_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^96 = 2^(64+32)
|
||||||
|
precomp->s4h_pow_red[k] = MODQ(precomp->s4l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(96+24)
|
||||||
|
}
|
||||||
|
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
precomp->res_bit_size = min_res_bs;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp() {
|
||||||
|
q120_mat1col_product_bbb_precomp* res = malloc(sizeof(q120_mat1col_product_bbb_precomp));
|
||||||
|
vec_mat1col_product_bbb_precomp(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* addr) { free(addr); }
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120b* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products:
|
||||||
|
* - x_i = xl_i + xh_i . 2^32
|
||||||
|
* - y_i = yl_i + yh_i . 2^32
|
||||||
|
* - A_i = xl_i . yl_i
|
||||||
|
* - B_i = xl_i . yh_i
|
||||||
|
* - C_i = xh_i . yl_i
|
||||||
|
* - D_i = xh_i . yh_i
|
||||||
|
* - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64
|
||||||
|
* 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts
|
||||||
|
* - A_i = Al_i + Ah_i . 2^32
|
||||||
|
* - B_i = Bl_i + Bh_i . 2^32
|
||||||
|
* - C_i = Cl_i + Ch_i . 2^32
|
||||||
|
* - D_i = Dl_i + Dh_i . 2^32
|
||||||
|
* 3. Compute the sums:
|
||||||
|
* - S1 = \sum Al_i
|
||||||
|
* - S2 = \sum (Ah_i + Bl_i + Cl_i)
|
||||||
|
* - S3 = \sum (Bh_i + Ch_i + Dl_i)
|
||||||
|
* - S4 = \sum Dh_i
|
||||||
|
* - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) +
|
||||||
|
* log2(3) bits
|
||||||
|
* - for ell == 10000 S2, S3 have < 47 bits
|
||||||
|
* 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2))
|
||||||
|
* - S1 = S1l + S1h . 2^24
|
||||||
|
* - S2 = S2l + S2h . 2^24
|
||||||
|
* - S3 = S3l + S3h . 2^24
|
||||||
|
* - S4 = S4l + S4h . 2^24
|
||||||
|
* 5. Compute final result as:
|
||||||
|
* - \sum x_i . y_i = S1l + S1h . 2^24
|
||||||
|
* + S2l . 2^32 + S2h . 2^(32+24)
|
||||||
|
* + S3l . 2^64 + S3h . 2^(64 + 24)
|
||||||
|
* + S4l . 2^96 + S4l . 2^(96+24)
|
||||||
|
* - here the powers of 2 are reduced modulo the primes Q before
|
||||||
|
* multiplications
|
||||||
|
* - the result will be on 24 + 3 + bit size of primes Q
|
||||||
|
*/
|
||||||
|
const uint64_t H1 = 32;
|
||||||
|
const uint64_t MASK1 = (UINT64_C(1) << H1) - 1;
|
||||||
|
|
||||||
|
uint64_t s1[4] = {0, 0, 0, 0};
|
||||||
|
uint64_t s2[4] = {0, 0, 0, 0};
|
||||||
|
uint64_t s3[4] = {0, 0, 0, 0};
|
||||||
|
uint64_t s4[4] = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||||
|
const uint64_t* const y_ptr = (uint64_t*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||||
|
for (uint64_t j = 0; j < 4; ++j) {
|
||||||
|
const uint64_t xl = x_ptr[i + j] & MASK1;
|
||||||
|
const uint64_t xh = x_ptr[i + j] >> H1;
|
||||||
|
const uint64_t yl = y_ptr[i + j] & MASK1;
|
||||||
|
const uint64_t yh = y_ptr[i + j] >> H1;
|
||||||
|
|
||||||
|
const uint64_t a = xl * yl;
|
||||||
|
const uint64_t al = a & MASK1;
|
||||||
|
const uint64_t ah = a >> H1;
|
||||||
|
|
||||||
|
const uint64_t b = xl * yh;
|
||||||
|
const uint64_t bl = b & MASK1;
|
||||||
|
const uint64_t bh = b >> H1;
|
||||||
|
|
||||||
|
const uint64_t c = xh * yl;
|
||||||
|
const uint64_t cl = c & MASK1;
|
||||||
|
const uint64_t ch = c >> H1;
|
||||||
|
|
||||||
|
const uint64_t d = xh * yh;
|
||||||
|
const uint64_t dl = d & MASK1;
|
||||||
|
const uint64_t dh = d >> H1;
|
||||||
|
|
||||||
|
s1[j] += al;
|
||||||
|
s2[j] += ah + bl + cl;
|
||||||
|
s3[j] += bh + ch + dl;
|
||||||
|
s4[j] += dh;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||||
|
|
||||||
|
uint64_t* const res_ptr = (uint64_t*)res;
|
||||||
|
for (uint64_t j = 0; j < 4; ++j) {
|
||||||
|
const uint64_t s1l = s1[j] & MASK2;
|
||||||
|
const uint64_t s1h = s1[j] >> H2;
|
||||||
|
const uint64_t s2l = s2[j] & MASK2;
|
||||||
|
const uint64_t s2h = s2[j] >> H2;
|
||||||
|
const uint64_t s3l = s3[j] & MASK2;
|
||||||
|
const uint64_t s3h = s3[j] >> H2;
|
||||||
|
const uint64_t s4l = s4[j] & MASK2;
|
||||||
|
const uint64_t s4h = s4[j] >> H2;
|
||||||
|
|
||||||
|
uint64_t t = s1l;
|
||||||
|
t += s1h * precomp->s1h_pow_red[j];
|
||||||
|
t += s2l * precomp->s2l_pow_red[j];
|
||||||
|
t += s2h * precomp->s2h_pow_red[j];
|
||||||
|
t += s3l * precomp->s3l_pow_red[j];
|
||||||
|
t += s3h * precomp->s3h_pow_red[j];
|
||||||
|
t += s4l * precomp->s4l_pow_red[j];
|
||||||
|
t += s4h * precomp->s4h_pow_red[j];
|
||||||
|
|
||||||
|
res_ptr[j] = t;
|
||||||
|
assert(log2(res_ptr[j]) < precomp->res_bit_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* precomp) {
|
||||||
|
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||||
|
|
||||||
|
double min_res_bs = 1000;
|
||||||
|
uint64_t min_h = -1;
|
||||||
|
|
||||||
|
double pow2_32_bs = comp_bit_size_red(32, qs);
|
||||||
|
|
||||||
|
double ell_bs = log2((double)MAX_ELL);
|
||||||
|
double s1_bs = 32 + ell_bs;
|
||||||
|
for (uint64_t h = 16; h < 32; ++h) {
|
||||||
|
double s2l_bs = pow2_32_bs + h;
|
||||||
|
double s2h_bs = s1_bs - h + comp_bit_size_red(32 + h, qs);
|
||||||
|
|
||||||
|
const double bs[] = {s1_bs, s2l_bs, s2h_bs};
|
||||||
|
const double res_bs = comp_bit_size_sum(3, bs);
|
||||||
|
|
||||||
|
if (min_res_bs > res_bs) {
|
||||||
|
min_res_bs = res_bs;
|
||||||
|
min_h = h;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(min_res_bs < 64);
|
||||||
|
precomp->h = min_h;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]);
|
||||||
|
precomp->s2h_pow_red[k] = MODQ(UINT64_C(1) << (32 + precomp->h), qs[k]);
|
||||||
|
}
|
||||||
|
#ifndef NDEBUG
|
||||||
|
precomp->res_bit_size = min_res_bs;
|
||||||
|
#endif
|
||||||
|
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp() {
|
||||||
|
q120_mat1col_product_bbc_precomp* res = malloc(sizeof(q120_mat1col_product_bbc_precomp));
|
||||||
|
vec_mat1col_product_bbc_precomp(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* addr) { free(addr); }
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbc_ref_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
/**
|
||||||
|
* Algorithm:
|
||||||
|
* 0. We have
|
||||||
|
* - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q
|
||||||
|
* 1. Split x_i in 2 32-bit parts and compute the cross-products:
|
||||||
|
* - x_i = xl_i + xh_i . 2^32
|
||||||
|
* - A_i = xl_i . y1_i
|
||||||
|
* - B_i = xh_i . y2_i
|
||||||
|
* - we have x_i . y_i == A_i + B_i
|
||||||
|
* 2. Split A_i and B_i into 2 32-bit parts
|
||||||
|
* - A_i = Al_i + Ah_i . 2^32
|
||||||
|
* - B_i = Bl_i + Bh_i . 2^32
|
||||||
|
* 3. Compute the sums:
|
||||||
|
* - S1 = \sum Al_i + Bl_i
|
||||||
|
* - S2 = \sum Ah_i + Bh_i
|
||||||
|
* - here S1 and S2 have 32 + log2(ell) bits
|
||||||
|
* - for ell == 10000 S1, S2 have < 46 bits
|
||||||
|
* 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46)
|
||||||
|
* - S2 = S2l + S2h . 2^27
|
||||||
|
* 5. Compute final result as:
|
||||||
|
* - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27)
|
||||||
|
* - here the powers of 2 are reduced modulo the primes Q before
|
||||||
|
* multiplications
|
||||||
|
* - the result will be on < 52 bits
|
||||||
|
*/
|
||||||
|
|
||||||
|
const uint64_t H1 = 32;
|
||||||
|
const uint64_t MASK1 = (UINT64_C(1) << H1) - 1;
|
||||||
|
|
||||||
|
uint64_t s1[4] = {0, 0, 0, 0};
|
||||||
|
uint64_t s2[4] = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||||
|
const uint32_t* const y_ptr = (uint32_t*)y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||||
|
for (uint64_t j = 0; j < 4; ++j) {
|
||||||
|
const uint64_t xl = x_ptr[i + j] & MASK1;
|
||||||
|
const uint64_t xh = x_ptr[i + j] >> H1;
|
||||||
|
const uint64_t y0 = y_ptr[2 * (i + j)];
|
||||||
|
const uint64_t y1 = y_ptr[2 * (i + j) + 1];
|
||||||
|
|
||||||
|
const uint64_t a = xl * y0;
|
||||||
|
const uint64_t al = a & MASK1;
|
||||||
|
const uint64_t ah = a >> H1;
|
||||||
|
|
||||||
|
const uint64_t b = xh * y1;
|
||||||
|
const uint64_t bl = b & MASK1;
|
||||||
|
const uint64_t bh = b >> H1;
|
||||||
|
|
||||||
|
s1[j] += al + bl;
|
||||||
|
s2[j] += ah + bh;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||||
|
|
||||||
|
uint64_t* const res_ptr = (uint64_t*)res;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
const uint64_t s2l = s2[k] & MASK2;
|
||||||
|
const uint64_t s2h = s2[k] >> H2;
|
||||||
|
|
||||||
|
uint64_t t = s1[k];
|
||||||
|
t += s2l * precomp->s2l_pow_red[k];
|
||||||
|
t += s2h * precomp->s2h_pow_red[k];
|
||||||
|
|
||||||
|
res_ptr[k] = t;
|
||||||
|
assert(log2(res_ptr[k]) < precomp->res_bit_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __always_inline void accum_mul_q120_bc(uint64_t res[8], //
|
||||||
|
const uint32_t x_layb[8], const uint32_t y_layc[8]) {
|
||||||
|
for (uint64_t i = 0; i < 4; ++i) {
|
||||||
|
static const uint64_t MASK32 = 0xFFFFFFFFUL;
|
||||||
|
uint64_t x_lo = x_layb[2 * i];
|
||||||
|
uint64_t x_hi = x_layb[2 * i + 1];
|
||||||
|
uint64_t y_lo = y_layc[2 * i];
|
||||||
|
uint64_t y_hi = y_layc[2 * i + 1];
|
||||||
|
uint64_t xy_lo = x_lo * y_lo;
|
||||||
|
uint64_t xy_hi = x_hi * y_hi;
|
||||||
|
res[2 * i] += (xy_lo & MASK32) + (xy_hi & MASK32);
|
||||||
|
res[2 * i + 1] += (xy_lo >> 32) + (xy_hi >> 32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __always_inline void accum_to_q120b(uint64_t res[4], //
|
||||||
|
const uint64_t s[8], const q120_mat1col_product_bbc_precomp* precomp) {
|
||||||
|
const uint64_t H2 = precomp->h;
|
||||||
|
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
const uint64_t s2l = s[2 * k + 1] & MASK2;
|
||||||
|
const uint64_t s2h = s[2 * k + 1] >> H2;
|
||||||
|
uint64_t t = s[2 * k];
|
||||||
|
t += s2l * precomp->s2l_pow_red[k];
|
||||||
|
t += s2h * precomp->s2h_pow_red[k];
|
||||||
|
res[k] = t;
|
||||||
|
assert(log2(res[k]) < precomp->res_bit_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
|
||||||
|
const uint32_t(*const x_ptr)[8] = (const uint32_t(*const)[8])x;
|
||||||
|
const uint32_t(*const y_ptr)[8] = (const uint32_t(*const)[8])y;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; i++) {
|
||||||
|
accum_mul_q120_bc(s, x_ptr[i], y_ptr[i]);
|
||||||
|
}
|
||||||
|
accum_to_q120b((uint64_t*)res, s, precomp);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
uint64_t s[2][16] = {0};
|
||||||
|
|
||||||
|
const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x;
|
||||||
|
const uint32_t(*const y_ptr)[2][8] = (const uint32_t(*const)[2][8])y;
|
||||||
|
uint64_t(*re)[4] = (uint64_t(*)[4])res;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; i++) {
|
||||||
|
accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]);
|
||||||
|
accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]);
|
||||||
|
}
|
||||||
|
accum_to_q120b(re[0], s[0], precomp);
|
||||||
|
accum_to_q120b(re[1], s[1], precomp);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||||
|
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||||
|
uint64_t s[4][16] = {0};
|
||||||
|
|
||||||
|
const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x;
|
||||||
|
const uint32_t(*const y_ptr)[4][8] = (const uint32_t(*const)[4][8])y;
|
||||||
|
uint64_t(*re)[4] = (uint64_t(*)[4])res;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < ell; i++) {
|
||||||
|
accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]);
|
||||||
|
accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]);
|
||||||
|
accum_mul_q120_bc(s[2], x_ptr[i][0], y_ptr[i][2]);
|
||||||
|
accum_mul_q120_bc(s[3], x_ptr[i][1], y_ptr[i][3]);
|
||||||
|
}
|
||||||
|
accum_to_q120b(re[0], s[0], precomp);
|
||||||
|
accum_to_q120b(re[1], s[1], precomp);
|
||||||
|
accum_to_q120b(re[2], s[2], precomp);
|
||||||
|
accum_to_q120b(re[3], s[3], precomp);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2b* const dst, // 8 doubles
|
||||||
|
const q120b* const src // a q120b vector
|
||||||
|
) {
|
||||||
|
const uint64_t* in = (uint64_t*)src;
|
||||||
|
uint64_t* out = (uint64_t*)dst;
|
||||||
|
for (uint64_t i = 0; i < 8; ++i) {
|
||||||
|
out[i] = in[8 * blk + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// function on layout c is the exact same as on layout b
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak q120x2_extract_1blk_from_q120c_ref = q120x2_extract_1blk_from_q120b_ref
|
||||||
|
#else
|
||||||
|
EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120x2c* const dst, // 8 doubles
|
||||||
|
const q120c* const src // a q120c vector
|
||||||
|
) __attribute__((alias("q120x2_extract_1blk_from_q120b_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref(
|
||||||
|
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||||
|
q120x2b* const dst, // nrows * 2 q120
|
||||||
|
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||||
|
) {
|
||||||
|
const uint64_t* in = (uint64_t*)src;
|
||||||
|
uint64_t* out = (uint64_t*)dst;
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
for (uint64_t i = 0; i < 8; ++i) {
|
||||||
|
out[i] = in[8 * blk + i];
|
||||||
|
}
|
||||||
|
in += 4 * nn;
|
||||||
|
out += 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk,
|
||||||
|
q120b* dest, // 1 reim vector of length m
|
||||||
|
const q120x2b* src // 8 doubles
|
||||||
|
) {
|
||||||
|
const uint64_t* in = (uint64_t*)src;
|
||||||
|
uint64_t* out = (uint64_t*)dest;
|
||||||
|
for (uint64_t i = 0; i < 8; ++i) {
|
||||||
|
out[8 * blk + i] = in[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
111
spqlios/lib/spqlios/q120/q120_arithmetic_simple.c
Normal file
111
spqlios/lib/spqlios/q120/q120_arithmetic_simple.c
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "q120_arithmetic.h"
|
||||||
|
#include "q120_common.h"
|
||||||
|
|
||||||
|
EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y) {
|
||||||
|
const uint64_t* x_u64 = (uint64_t*)x;
|
||||||
|
const uint64_t* y_u64 = (uint64_t*)y;
|
||||||
|
uint64_t* res_u64 = (uint64_t*)res;
|
||||||
|
for (uint64_t i = 0; i < 4 * nn; i += 4) {
|
||||||
|
res_u64[i + 0] = x_u64[i + 0] % ((uint64_t)Q1 << 33) + y_u64[i + 0] % ((uint64_t)Q1 << 33);
|
||||||
|
res_u64[i + 1] = x_u64[i + 1] % ((uint64_t)Q2 << 33) + y_u64[i + 1] % ((uint64_t)Q2 << 33);
|
||||||
|
res_u64[i + 2] = x_u64[i + 2] % ((uint64_t)Q3 << 33) + y_u64[i + 2] % ((uint64_t)Q3 << 33);
|
||||||
|
res_u64[i + 3] = x_u64[i + 3] % ((uint64_t)Q4 << 33) + y_u64[i + 3] % ((uint64_t)Q4 << 33);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y) {
|
||||||
|
const uint32_t* x_u32 = (uint32_t*)x;
|
||||||
|
const uint32_t* y_u32 = (uint32_t*)y;
|
||||||
|
uint32_t* res_u32 = (uint32_t*)res;
|
||||||
|
for (uint64_t i = 0; i < 8 * nn; i += 8) {
|
||||||
|
res_u32[i + 0] = (uint32_t)(((uint64_t)x_u32[i + 0] + (uint64_t)y_u32[i + 0]) % Q1);
|
||||||
|
res_u32[i + 1] = (uint32_t)(((uint64_t)x_u32[i + 1] + (uint64_t)y_u32[i + 1]) % Q1);
|
||||||
|
res_u32[i + 2] = (uint32_t)(((uint64_t)x_u32[i + 2] + (uint64_t)y_u32[i + 2]) % Q2);
|
||||||
|
res_u32[i + 3] = (uint32_t)(((uint64_t)x_u32[i + 3] + (uint64_t)y_u32[i + 3]) % Q2);
|
||||||
|
res_u32[i + 4] = (uint32_t)(((uint64_t)x_u32[i + 4] + (uint64_t)y_u32[i + 4]) % Q3);
|
||||||
|
res_u32[i + 5] = (uint32_t)(((uint64_t)x_u32[i + 5] + (uint64_t)y_u32[i + 5]) % Q3);
|
||||||
|
res_u32[i + 6] = (uint32_t)(((uint64_t)x_u32[i + 6] + (uint64_t)y_u32[i + 6]) % Q4);
|
||||||
|
res_u32[i + 7] = (uint32_t)(((uint64_t)x_u32[i + 7] + (uint64_t)y_u32[i + 7]) % Q4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x) {
|
||||||
|
const uint64_t* x_u64 = (uint64_t*)x;
|
||||||
|
uint32_t* res_u32 = (uint32_t*)res;
|
||||||
|
for (uint64_t i = 0, j = 0; i < 4 * nn; i += 4, j += 8) {
|
||||||
|
res_u32[j + 0] = x_u64[i + 0] % Q1;
|
||||||
|
res_u32[j + 1] = ((uint64_t)res_u32[j + 0] << 32) % Q1;
|
||||||
|
res_u32[j + 2] = x_u64[i + 1] % Q2;
|
||||||
|
res_u32[j + 3] = ((uint64_t)res_u32[j + 2] << 32) % Q2;
|
||||||
|
res_u32[j + 4] = x_u64[i + 2] % Q3;
|
||||||
|
res_u32[j + 5] = ((uint64_t)res_u32[j + 4] << 32) % Q3;
|
||||||
|
res_u32[j + 6] = x_u64[i + 3] % Q4;
|
||||||
|
res_u32[j + 7] = ((uint64_t)res_u32[j + 6] << 32) % Q4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x) {
|
||||||
|
static const int64_t MASK_HI = INT64_C(0x8000000000000000);
|
||||||
|
static const int64_t MASK_LO = ~MASK_HI;
|
||||||
|
static const uint64_t OQ[4] = {
|
||||||
|
(Q1 - (UINT64_C(0x8000000000000000) % Q1)),
|
||||||
|
(Q2 - (UINT64_C(0x8000000000000000) % Q2)),
|
||||||
|
(Q3 - (UINT64_C(0x8000000000000000) % Q3)),
|
||||||
|
(Q4 - (UINT64_C(0x8000000000000000) % Q4)),
|
||||||
|
};
|
||||||
|
uint64_t* res_u64 = (uint64_t*)res;
|
||||||
|
for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) {
|
||||||
|
uint64_t xj_lo = x[j] & MASK_LO;
|
||||||
|
uint64_t xj_hi = x[j] & MASK_HI;
|
||||||
|
res_u64[i + 0] = xj_lo + (xj_hi ? OQ[0] : 0);
|
||||||
|
res_u64[i + 1] = xj_lo + (xj_hi ? OQ[1] : 0);
|
||||||
|
res_u64[i + 2] = xj_lo + (xj_hi ? OQ[2] : 0);
|
||||||
|
res_u64[i + 3] = xj_lo + (xj_hi ? OQ[3] : 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t posmod(int64_t x, int64_t q) {
|
||||||
|
int64_t t = x % q;
|
||||||
|
if (t < 0)
|
||||||
|
return t + q;
|
||||||
|
else
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x) {
|
||||||
|
uint32_t* res_u32 = (uint32_t*)res;
|
||||||
|
for (uint64_t i = 0, j = 0; j < nn; i += 8, ++j) {
|
||||||
|
res_u32[i + 0] = posmod(x[j], Q1);
|
||||||
|
res_u32[i + 1] = ((uint64_t)res_u32[i + 0] << 32) % Q1;
|
||||||
|
res_u32[i + 2] = posmod(x[j], Q2);
|
||||||
|
res_u32[i + 3] = ((uint64_t)res_u32[i + 2] << 32) % Q2;
|
||||||
|
res_u32[i + 4] = posmod(x[j], Q3);
|
||||||
|
res_u32[i + 5] = ((uint64_t)res_u32[i + 4] << 32) % Q3;
|
||||||
|
res_u32[i + 6] = posmod(x[j], Q4);
|
||||||
|
res_u32[i + 7] = ((uint64_t)res_u32[i + 6] << 32) % Q4;
|
||||||
|
;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x) {
|
||||||
|
static const __int128_t Q = (__int128_t)Q1 * Q2 * Q3 * Q4;
|
||||||
|
static const __int128_t Qm1 = (__int128_t)Q2 * Q3 * Q4;
|
||||||
|
static const __int128_t Qm2 = (__int128_t)Q1 * Q3 * Q4;
|
||||||
|
static const __int128_t Qm3 = (__int128_t)Q1 * Q2 * Q4;
|
||||||
|
static const __int128_t Qm4 = (__int128_t)Q1 * Q2 * Q3;
|
||||||
|
|
||||||
|
const uint64_t* x_u64 = (uint64_t*)x;
|
||||||
|
for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) {
|
||||||
|
__int128_t tmp = 0;
|
||||||
|
tmp += (((x_u64[i + 0] % Q1) * Q1_CRT_CST) % Q1) * Qm1;
|
||||||
|
tmp += (((x_u64[i + 1] % Q2) * Q2_CRT_CST) % Q2) * Qm2;
|
||||||
|
tmp += (((x_u64[i + 2] % Q3) * Q3_CRT_CST) % Q3) * Qm3;
|
||||||
|
tmp += (((x_u64[i + 3] % Q4) * Q4_CRT_CST) % Q4) * Qm4;
|
||||||
|
tmp %= Q;
|
||||||
|
res[j] = (tmp >= (Q + 1) / 2) ? tmp - Q : tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
94
spqlios/lib/spqlios/q120/q120_common.h
Normal file
94
spqlios/lib/spqlios/q120/q120_common.h
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
#ifndef SPQLIOS_Q120_COMMON_H
|
||||||
|
#define SPQLIOS_Q120_COMMON_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \
|
||||||
|
!defined(SPQLIOS_Q120_USE_31_BIT_PRIMES)
|
||||||
|
#define SPQLIOS_Q120_USE_30_BIT_PRIMES
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 29-bit primes and 2*2^16 roots of unity
|
||||||
|
*/
|
||||||
|
#ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES
|
||||||
|
#define Q1 ((1u << 29) - 2 * (1u << 17) + 1)
|
||||||
|
#define OMEGA1 78289835
|
||||||
|
#define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1
|
||||||
|
|
||||||
|
#define Q2 ((1u << 29) - 5 * (1u << 17) + 1)
|
||||||
|
#define OMEGA2 178519192
|
||||||
|
#define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2
|
||||||
|
|
||||||
|
#define Q3 ((1u << 29) - 26 * (1u << 17) + 1)
|
||||||
|
#define OMEGA3 483889678
|
||||||
|
#define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3
|
||||||
|
|
||||||
|
#define Q4 ((1u << 29) - 35 * (1u << 17) + 1)
|
||||||
|
#define OMEGA4 239808033
|
||||||
|
#define Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 30-bit primes and 2*2^16 roots of unity
|
||||||
|
*/
|
||||||
|
#ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES
|
||||||
|
#define Q1 ((1u << 30) - 2 * (1u << 17) + 1)
|
||||||
|
#define OMEGA1 1070907127
|
||||||
|
#define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1
|
||||||
|
|
||||||
|
#define Q2 ((1u << 30) - 17 * (1u << 17) + 1)
|
||||||
|
#define OMEGA2 315046632
|
||||||
|
#define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2
|
||||||
|
|
||||||
|
#define Q3 ((1u << 30) - 23 * (1u << 17) + 1)
|
||||||
|
#define OMEGA3 309185662
|
||||||
|
#define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3
|
||||||
|
|
||||||
|
#define Q4 ((1u << 30) - 42 * (1u << 17) + 1)
|
||||||
|
#define OMEGA4 846468380
|
||||||
|
#define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 31-bit primes and 2*2^16 roots of unity
|
||||||
|
*/
|
||||||
|
#ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES
|
||||||
|
#define Q1 ((1u << 31) - 1 * (1u << 17) + 1)
|
||||||
|
#define OMEGA1 1615402923
|
||||||
|
#define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1
|
||||||
|
|
||||||
|
#define Q2 ((1u << 31) - 4 * (1u << 17) + 1)
|
||||||
|
#define OMEGA2 1137738560
|
||||||
|
#define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2
|
||||||
|
|
||||||
|
#define Q3 ((1u << 31) - 11 * (1u << 17) + 1)
|
||||||
|
#define OMEGA3 154880552
|
||||||
|
#define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3
|
||||||
|
|
||||||
|
#define Q4 ((1u << 31) - 23 * (1u << 17) + 1)
|
||||||
|
#define OMEGA4 558784885
|
||||||
|
#define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4};
|
||||||
|
static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4};
|
||||||
|
|
||||||
|
#define MAX_ELL 10000
|
||||||
|
|
||||||
|
// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
|
||||||
|
// each between [0 and 2^32-1]
|
||||||
|
typedef struct _q120a q120a;
|
||||||
|
|
||||||
|
// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
|
||||||
|
// each between [0 and 2^64-1]
|
||||||
|
typedef struct _q120b q120b;
|
||||||
|
|
||||||
|
// each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x
|
||||||
|
// mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1]
|
||||||
|
typedef struct _q120c q120c;
|
||||||
|
|
||||||
|
typedef struct _q120x2b q120x2b;
|
||||||
|
typedef struct _q120x2c q120x2c;
|
||||||
|
|
||||||
|
#endif // SPQLIOS_Q120_COMMON_H
|
||||||
5
spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c
Normal file
5
spqlios/lib/spqlios/q120/q120_fallbacks_aarch64.c
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#include "q120_ntt_private.h"
|
||||||
|
|
||||||
|
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }
|
||||||
|
|
||||||
|
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }
|
||||||
340
spqlios/lib/spqlios/q120/q120_ntt.c
Normal file
340
spqlios/lib/spqlios/q120/q120_ntt.c
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <inttypes.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "q120_ntt_private.h"
|
||||||
|
|
||||||
|
q120_ntt_precomp* new_precomp(const uint64_t n) {
|
||||||
|
q120_ntt_precomp* precomp = malloc(sizeof(*precomp));
|
||||||
|
precomp->n = n;
|
||||||
|
|
||||||
|
assert(n && !(n & (n - 1)) && n <= (1 << 16)); // n is a power of 2 smaller than 2^16
|
||||||
|
const uint64_t logN = ceil(log2(n));
|
||||||
|
precomp->level_metadata = malloc((logN + 2) * sizeof(*precomp->level_metadata));
|
||||||
|
|
||||||
|
precomp->powomega = spqlios_alloc_custom_align(32, 4 * 2 * n * sizeof(*(precomp->powomega)));
|
||||||
|
|
||||||
|
return precomp;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t modq_pow(const uint32_t x, const int64_t n, const uint32_t q) {
|
||||||
|
uint64_t np = (n % (q - 1) + q - 1) % (q - 1);
|
||||||
|
|
||||||
|
uint64_t val_pow = x;
|
||||||
|
uint64_t res = 1;
|
||||||
|
while (np != 0) {
|
||||||
|
if (np & 1) res = (res * val_pow) % q;
|
||||||
|
val_pow = (val_pow * val_pow) % q;
|
||||||
|
np >>= 1;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_omegas(const uint64_t n, uint32_t omegas[4]) {
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
omegas[k] = modq_pow(OMEGAS_VEC[k], (1 << 16) / n, PRIMES_VEC[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
|
||||||
|
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||||
|
for (int k = 1; k < 4; ++k) {
|
||||||
|
if (logQ != ceil(log2(PRIMES_VEC[k]))) {
|
||||||
|
fprintf(stderr, "The 4 primes must have the same bit-size\n");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if each omega is a 2.n primitive root of unity
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
assert(modq_pow(omegas[k], 2 * n, PRIMES_VEC[k]) == 1);
|
||||||
|
for (uint64_t i = 1; i < 2 * n; ++i) {
|
||||||
|
assert(modq_pow(omegas[k], i, PRIMES_VEC[k]) != 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (logQ > 31) {
|
||||||
|
fprintf(stderr, "Modulus q bit-size is larger than 30 bit\n");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t fill_reduction_meta(const uint64_t bs_start, q120_ntt_reduc_step_precomp* reduc_metadata) {
|
||||||
|
// fill reduction metadata
|
||||||
|
uint64_t bs_after_reduc = -1;
|
||||||
|
{
|
||||||
|
uint64_t min_h = -1;
|
||||||
|
|
||||||
|
for (uint64_t h = bs_start / 2; h < bs_start; ++h) {
|
||||||
|
uint64_t t = 0;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
const uint64_t t1 = bs_start - h + (uint64_t)ceil(log2((UINT64_C(1) << h) % PRIMES_VEC[k]));
|
||||||
|
const uint64_t t2 = UINT64_C(1) + ((t1 > h) ? t1 : h);
|
||||||
|
if (t < t2) t = t2;
|
||||||
|
}
|
||||||
|
if (t < bs_after_reduc) {
|
||||||
|
min_h = h;
|
||||||
|
bs_after_reduc = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reduc_metadata->h = min_h;
|
||||||
|
reduc_metadata->mask = (UINT64_C(1) << min_h) - 1;
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
reduc_metadata->modulo_red_cst[k] = (UINT64_C(1) << min_h) % PRIMES_VEC[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(bs_after_reduc < 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
return bs_after_reduc;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t round_up_half_n(const uint64_t n) { return (n + 1) / 2; }
|
||||||
|
|
||||||
|
EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n) {
|
||||||
|
uint32_t omega_vec[4];
|
||||||
|
fill_omegas(n, omega_vec);
|
||||||
|
|
||||||
|
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||||
|
|
||||||
|
q120_ntt_precomp* precomp = new_precomp(n);
|
||||||
|
|
||||||
|
uint64_t bs = precomp->input_bit_size = 64;
|
||||||
|
|
||||||
|
LOG("NTT parameters:\n");
|
||||||
|
LOG("\tsize = %" PRIu64 "\n", n)
|
||||||
|
LOG("\tlogQ = %" PRIu64 "\n", logQ);
|
||||||
|
LOG("\tinput bit-size = %" PRIu64 "\n", bs);
|
||||||
|
|
||||||
|
if (n == 1) return precomp;
|
||||||
|
|
||||||
|
// fill reduction metadata
|
||||||
|
uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata));
|
||||||
|
|
||||||
|
// forward metadata
|
||||||
|
q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata;
|
||||||
|
|
||||||
|
// first level a_k.omega^k
|
||||||
|
{
|
||||||
|
const uint64_t half_bs = (bs + 1) / 2;
|
||||||
|
level_metadata_ptr->half_bs = half_bs;
|
||||||
|
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||||
|
level_metadata_ptr->bs = bs = half_bs + logQ + 1;
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 " (a_k.omega^k) \n", n, bs);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t nn = n; nn >= 4; nn /= 2) {
|
||||||
|
level_metadata_ptr->reduce = (bs == 64);
|
||||||
|
if (level_metadata_ptr->reduce) {
|
||||||
|
bs = bs_after_reduc;
|
||||||
|
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||||
|
|
||||||
|
double bs_1 = bs + 1; // bit-size of term a+b or a-b
|
||||||
|
|
||||||
|
const uint64_t half_bs = round_up_half_n(bs_1);
|
||||||
|
uint64_t bs_2 = half_bs + logQ + 1; // bit-size of term (a-b).omega^k
|
||||||
|
bs = (bs_1 > bs_2) ? bs_1 : bs_2;
|
||||||
|
assert(bs <= 64);
|
||||||
|
|
||||||
|
level_metadata_ptr->bs = bs;
|
||||||
|
level_metadata_ptr->half_bs = half_bs;
|
||||||
|
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// last level (a-b, a+b)
|
||||||
|
{
|
||||||
|
level_metadata_ptr->reduce = (bs == 64);
|
||||||
|
if (level_metadata_ptr->reduce) {
|
||||||
|
bs = bs_after_reduc;
|
||||||
|
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||||
|
}
|
||||||
|
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = ((uint64_t)PRIMES_VEC[k] << (bs - logQ));
|
||||||
|
level_metadata_ptr->bs = ++bs;
|
||||||
|
level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used
|
||||||
|
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs);
|
||||||
|
}
|
||||||
|
precomp->output_bit_size = bs;
|
||||||
|
|
||||||
|
// omega powers
|
||||||
|
uint64_t* powomega = malloc(sizeof(*powomega) * 2 * n);
|
||||||
|
for (uint64_t k = 0; k < 4; ++k) {
|
||||||
|
const uint64_t q = PRIMES_VEC[k];
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 2 * n; ++i) {
|
||||||
|
powomega[i] = modq_pow(omega_vec[k], i, q);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t* powomega_ptr = precomp->powomega + k;
|
||||||
|
level_metadata_ptr = precomp->level_metadata;
|
||||||
|
|
||||||
|
{
|
||||||
|
// const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs;
|
||||||
|
for (uint64_t i = 0; i < n; ++i) {
|
||||||
|
uint64_t t = powomega[i];
|
||||||
|
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||||
|
powomega_ptr[4 * i] = (t1 << 32) + t;
|
||||||
|
}
|
||||||
|
powomega_ptr += 4 * n;
|
||||||
|
level_metadata_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t nn = n; nn >= 4; nn /= 2) {
|
||||||
|
const uint64_t halfnn = nn / 2;
|
||||||
|
const uint64_t m = n / halfnn;
|
||||||
|
|
||||||
|
// const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs;
|
||||||
|
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||||
|
uint64_t t = powomega[i * m];
|
||||||
|
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||||
|
powomega_ptr[4 * (i - 1)] = (t1 << 32) + t;
|
||||||
|
}
|
||||||
|
powomega_ptr += 4 * (halfnn - 1);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
free(powomega);
|
||||||
|
|
||||||
|
return precomp;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n) {
|
||||||
|
uint32_t omega_vec[4];
|
||||||
|
fill_omegas(n, omega_vec);
|
||||||
|
|
||||||
|
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||||
|
|
||||||
|
q120_ntt_precomp* precomp = new_precomp(n);
|
||||||
|
|
||||||
|
uint64_t bs = precomp->input_bit_size = 64;
|
||||||
|
|
||||||
|
LOG("iNTT parameters:\n");
|
||||||
|
LOG("\tsize = %" PRIu64 "\n", n)
|
||||||
|
LOG("\tlogQ = %" PRIu64 "\n", logQ);
|
||||||
|
LOG("\tinput bit-size = %" PRIu64 "\n", bs);
|
||||||
|
|
||||||
|
if (n == 1) return precomp;
|
||||||
|
|
||||||
|
// fill reduction metadata
|
||||||
|
uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata));
|
||||||
|
|
||||||
|
// backward metadata
|
||||||
|
q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata;
|
||||||
|
|
||||||
|
// first level (a+b, a-b) adds 1-bit
|
||||||
|
{
|
||||||
|
level_metadata_ptr->reduce = (bs == 64);
|
||||||
|
if (level_metadata_ptr->reduce) {
|
||||||
|
bs = bs_after_reduc;
|
||||||
|
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||||
|
|
||||||
|
level_metadata_ptr->bs = ++bs;
|
||||||
|
level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t nn = 4; nn <= n; nn *= 2) {
|
||||||
|
level_metadata_ptr->reduce = (bs == 64);
|
||||||
|
if (level_metadata_ptr->reduce) {
|
||||||
|
bs = bs_after_reduc;
|
||||||
|
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t half_bs = round_up_half_n(bs);
|
||||||
|
const uint64_t bs_mult = half_bs + logQ + 1; // bit-size of term b.omega^k
|
||||||
|
bs = 1 + ((bs > bs_mult) ? bs : bs_mult); // bit-size of a+b.omega^k or a-b.omega^k
|
||||||
|
assert(bs <= 64);
|
||||||
|
|
||||||
|
level_metadata_ptr->bs = bs;
|
||||||
|
level_metadata_ptr->half_bs = half_bs;
|
||||||
|
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||||
|
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs_mult - logQ);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// last level a_k.omega^k
|
||||||
|
{
|
||||||
|
level_metadata_ptr->reduce = (bs == 64);
|
||||||
|
if (level_metadata_ptr->reduce) {
|
||||||
|
bs = bs_after_reduc;
|
||||||
|
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t half_bs = round_up_half_n(bs);
|
||||||
|
|
||||||
|
bs = half_bs + logQ + 1; // bit-size of term a.omega^k
|
||||||
|
assert(bs <= 64);
|
||||||
|
|
||||||
|
level_metadata_ptr->bs = bs;
|
||||||
|
level_metadata_ptr->half_bs = half_bs;
|
||||||
|
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||||
|
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||||
|
|
||||||
|
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", n, bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// omega powers
|
||||||
|
uint32_t* powomegabar = malloc(sizeof(*powomegabar) * 2 * n);
|
||||||
|
for (int k = 0; k < 4; ++k) {
|
||||||
|
const uint64_t q = PRIMES_VEC[k];
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 2 * n; ++i) {
|
||||||
|
powomegabar[i] = modq_pow(omega_vec[k], -i, q);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t* powomega_ptr = precomp->powomega + k;
|
||||||
|
level_metadata_ptr = precomp->level_metadata + 1;
|
||||||
|
|
||||||
|
for (uint64_t nn = 4; nn <= n; nn *= 2) {
|
||||||
|
const uint64_t halfnn = nn / 2;
|
||||||
|
const uint64_t m = n / halfnn;
|
||||||
|
|
||||||
|
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||||
|
uint64_t t = powomegabar[i * m];
|
||||||
|
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||||
|
powomega_ptr[4 * (i - 1)] = (t1 << 32) + t;
|
||||||
|
}
|
||||||
|
powomega_ptr += 4 * (halfnn - 1);
|
||||||
|
level_metadata_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const uint64_t invNmod = modq_pow(n, -1, q);
|
||||||
|
for (uint64_t i = 0; i < n; ++i) {
|
||||||
|
uint64_t t = (powomegabar[i] * invNmod) % q;
|
||||||
|
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||||
|
powomega_ptr[4 * i] = (t1 << 32) + t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
free(powomegabar);
|
||||||
|
|
||||||
|
return precomp;
|
||||||
|
}
|
||||||
|
|
||||||
|
void del_precomp(q120_ntt_precomp* precomp) {
|
||||||
|
spqlios_free(precomp->powomega);
|
||||||
|
free(precomp->level_metadata);
|
||||||
|
free(precomp);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); }
|
||||||
|
|
||||||
|
EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); }
|
||||||
25
spqlios/lib/spqlios/q120/q120_ntt.h
Normal file
25
spqlios/lib/spqlios/q120/q120_ntt.h
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
#ifndef SPQLIOS_Q120_NTT_H
|
||||||
|
#define SPQLIOS_Q120_NTT_H
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
#include "q120_common.h"
|
||||||
|
|
||||||
|
typedef struct _q120_ntt_precomp q120_ntt_precomp;
|
||||||
|
|
||||||
|
EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n);
|
||||||
|
EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp);
|
||||||
|
|
||||||
|
EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n);
|
||||||
|
EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief computes a direct ntt in-place over data.
|
||||||
|
*/
|
||||||
|
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief computes an inverse ntt in-place over data.
|
||||||
|
*/
|
||||||
|
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_Q120_NTT_H
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user