mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Update READMEs & add SECURITY.md (#78)
This commit is contained in:
committed by
GitHub
parent
3b94ab047e
commit
ccd94e36cc
@@ -10,7 +10,7 @@ homepage = "https://github.com/phantomzone-org/poulpy"
|
||||
documentation = "https://docs.rs/poulpy"
|
||||
|
||||
[dependencies]
|
||||
poulpy-hal = "0.1.2"
|
||||
poulpy-hal = {path="../poulpy-hal"}
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
|
||||
@@ -1,15 +1,38 @@
|
||||
# 🐙 Poulpy-Backend
|
||||
|
||||
**Poulpy-Backend** is a Rust crate that provides concrete implementations of **`poulpy-hal`**. This crate is used to instantiate projects implemented with **`poulpy-hal`**, **`poulpy-core`** and/or **`poulpy-schemes`**.
|
||||
|
||||
## spqlios-arithmetic
|
||||
## Backends
|
||||
|
||||
### WSL/Ubuntu
|
||||
To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule:
|
||||
1) Initialize the sub-module
|
||||
2) $ cd backend/spqlios-arithmetic
|
||||
3) mdkir build
|
||||
4) cd build
|
||||
5) cmake ..
|
||||
6) make
|
||||
### cpu-spqlios
|
||||
|
||||
### Others
|
||||
Steps 3 to 6 might change depending of your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options.
|
||||
This module provides a CPU AVX2 accelerated backend through C bindings over [**spqlios-arithmetic**](https://github.com/tfhe/spqlios-arithmetic).
|
||||
|
||||
- Currently supported: `FFT64` backend
|
||||
- Planned: `NTT120` backend
|
||||
|
||||
### Build Notes
|
||||
|
||||
This backend is built and compiled automatically and has been tested on wsl/ubuntu.
|
||||
|
||||
- `cmake` is invoked automatically by the build script (`build.rs`) when compiling the crate.
|
||||
- No manual setup is required beyond having a standard Rust toolchain.
|
||||
- Build options can be changed in `/build/cpu_spqlios.rs`
|
||||
- Automatic build of cpu-spqlios/spqlios-arithmetic can be disabled in `build.rs`.
|
||||
|
||||
Spqlios-arithmetic is windows/mac compatible but building for those platforms is slightly different (see [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build)) and has not been tested in Poulpy.
|
||||
|
||||
### Example
|
||||
|
||||
```rust
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module};
|
||||
|
||||
let log_n: usize = 10;
|
||||
let module = Module<FFT64> = Module<FFT64>::new(1<<log_n);
|
||||
```
|
||||
|
||||
## Contributors
|
||||
|
||||
To add a backend, implement the open extension traits from **`poulpy-hal/oep`** for a struct that implements the `Backend` trait.
|
||||
This will automatically make your backend compatible with the API of **`poulpy-hal`**, **`poulpy-core`** and **`poulpy-schemes`**.
|
||||
Submodule poulpy-backend/src/cpu_spqlios/spqlios-arithmetic added at de62af3507
@@ -1,14 +0,0 @@
|
||||
# Use the Google style in this project.
|
||||
BasedOnStyle: Google
|
||||
|
||||
# Some folks prefer to write "int& foo" while others prefer "int &foo". The
|
||||
# Google Style Guide only asks for consistency within a project, we chose
|
||||
# "int& foo" for this project:
|
||||
DerivePointerAlignment: false
|
||||
PointerAlignment: Left
|
||||
|
||||
# The Google Style Guide only asks for consistency w.r.t. "east const" vs.
|
||||
# "const west" alignment of cv-qualifiers. In this project we use "east const".
|
||||
QualifierAlignment: Left
|
||||
|
||||
ColumnLimit: 120
|
||||
@@ -1,20 +0,0 @@
|
||||
name: Auto-Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Auto-Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 3
|
||||
# sparse-checkout: manifest.yaml scripts/auto-release.sh
|
||||
|
||||
- run:
|
||||
${{github.workspace}}/scripts/auto-release.sh
|
||||
@@ -1,6 +0,0 @@
|
||||
cmake-build-*
|
||||
.idea
|
||||
|
||||
build
|
||||
.vscode
|
||||
.*.sh
|
||||
@@ -1,69 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(spqlios)
|
||||
|
||||
# read the current version from the manifest file
|
||||
file(READ "manifest.yaml" manifest)
|
||||
string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest})
|
||||
#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}")
|
||||
set(SPQLIOS_VERSION ${CMAKE_MATCH_1})
|
||||
set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2})
|
||||
set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3})
|
||||
set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4})
|
||||
message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}")
|
||||
|
||||
#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath")
|
||||
set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors")
|
||||
set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests")
|
||||
set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)")
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE)
|
||||
endif()
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
if (WARNING_PARANOID)
|
||||
add_compile_options(-Wall -Werror -Wno-unused-command-line-argument)
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
|
||||
set(X86 ON)
|
||||
set(AARCH64 OFF)
|
||||
else ()
|
||||
set(X86 OFF)
|
||||
# set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets
|
||||
endif ()
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
|
||||
set(AARCH64 ON)
|
||||
endif ()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)")
|
||||
set(WIN32 ON)
|
||||
endif ()
|
||||
if (WIN32)
|
||||
#overrides for win32
|
||||
set(X86 OFF)
|
||||
set(AARCH64 OFF)
|
||||
set(X86_WIN32 ON)
|
||||
else()
|
||||
set(X86_WIN32 OFF)
|
||||
set(WIN32 OFF)
|
||||
endif (WIN32)
|
||||
|
||||
message(STATUS "--> WIN32: ${WIN32}")
|
||||
message(STATUS "--> X86_WIN32: ${X86_WIN32}")
|
||||
message(STATUS "--> X86_LINUX: ${X86}")
|
||||
message(STATUS "--> AARCH64: ${AARCH64}")
|
||||
|
||||
# compiles the main library in spqlios
|
||||
add_subdirectory(spqlios)
|
||||
|
||||
# compiles and activates unittests and itests
|
||||
if (${ENABLE_TESTING})
|
||||
enable_testing()
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# Contributing to SPQlios-fft
|
||||
|
||||
The spqlios-fft team encourages contributions.
|
||||
We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features.
|
||||
We encourage researchers to contribute with implementations of their FFT or NTT algorithms.
|
||||
In the following we are trying to give some guidance on how to contribute effectively.
|
||||
|
||||
## Communication ##
|
||||
|
||||
Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).
|
||||
|
||||
All communications are public, so please make sure to maintain professional behaviour in
|
||||
all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
|
||||
guidelines.
|
||||
|
||||
## Reporting Bugs or Requesting features ##
|
||||
|
||||
Bug should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).
|
||||
|
||||
Features can also be requested there, in this case, please ensure that the features you request are self-contained,
|
||||
easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
|
||||
possible.
|
||||
|
||||
## Setting up topic branches and generating pull requests
|
||||
|
||||
This section applies to people that already have write access to the repository. Specific instructions for pull-requests
|
||||
from public forks will be given later.
|
||||
|
||||
To implement some changes, please follow these steps:
|
||||
|
||||
- Create a "topic branch". Usually, the branch name should be `username/small-title`
|
||||
or better `username/issuenumber-small-title` where `issuenumber` is the number of
|
||||
the github issue number that is tackled.
|
||||
- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
|
||||
- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm`
|
||||
- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.
|
||||
|
||||
> _During the process, you may optionnally use `git rebase -i` to clean up your commit history. If you elect to do so,
|
||||
please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate
|
||||
and the human hours to fix them are not worth it. `Git merge` remains the preferred option._
|
||||
|
||||
- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage.
|
||||
|
||||
### Keep your pull requests limited to a single issue
|
||||
|
||||
Pull requests should be as small/atomic as possible.
|
||||
|
||||
### Coding Conventions
|
||||
|
||||
* Please make sure that your code is formatted according to the `.clang-format` file and
|
||||
that all files end with a newline character.
|
||||
* Please make sure that all the functions declared in the public api have relevant doxygen comments.
|
||||
Preferably, functions in the private apis should also contain a brief doxygen description.
|
||||
|
||||
### Versions and History
|
||||
|
||||
* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
|
||||
the form `x.y.z`
|
||||
* a patch release that increments `z` does not modify the stable API.
|
||||
* a minor release that increments `y` adds a new feature to the stable API.
|
||||
* In the unlikely case where we need to change or remove a feature, we will trigger a major release that
|
||||
increments `x`.
|
||||
|
||||
> _If any, we will mark those features as deprecated at least six months before the major release._
|
||||
|
||||
* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
|
||||
your own risk,
|
||||
but keep in mind that semantic versioning does not apply to them.
|
||||
|
||||
> _If you have a use-case that uses an experimental feature, we encourage
|
||||
> you to tell us about it, so that this feature reaches to the stable section faster!_
|
||||
|
||||
* **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
|
||||
get insight about
|
||||
the history of the repository (not the commit graph).
|
||||
|
||||
> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenario where multiple features are developped in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_
|
||||
@@ -1,18 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [2.0.0] - 2024-08-21
|
||||
|
||||
- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis.
|
||||
- Hardware acceleration available: AVX2 (most parts)
|
||||
- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode".
|
||||
|
||||
## [1.0.0] - 2023-07-18
|
||||
|
||||
- Initial release of the double precision fft on the reim and cplx backends
|
||||
- Coeffs-space conversions cplx <-> znx32 and tnx32
|
||||
- FFT-space conversions cplx <-> reim4 layouts
|
||||
- FFT-space multiplications on the cplx, reim and reim4 layouts.
|
||||
- In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
|
||||
@@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,65 +0,0 @@
|
||||
# SPQlios library
|
||||
|
||||
|
||||
|
||||
The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography.
|
||||
|
||||
<img src="docs/api-full.svg">
|
||||
|
||||
Namely, it is divided into 4 sections:
|
||||
|
||||
* The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space.
|
||||
* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. It proposed in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occurs on gadget-decomposed Ring-LWE coordinates.
|
||||
* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings.
|
||||
* The ZN section focuses over vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR).
|
||||
|
||||
### A high value target for hardware accelerations
|
||||
|
||||
SPQlios is more than a library, it is also a good target for hardware developers.
|
||||
On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operations is sufficiently large so that meaningful functions only require a few of these.
|
||||
|
||||
This makes the SPQlios API a high value target for hardware acceleration, that targets FHE.
|
||||
|
||||
### SPQLios is not an FHE library, but a huge enabler
|
||||
|
||||
SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
|
||||
interface can be used in a regular C code, and any other language via classical foreign APIs.
|
||||
|
||||
The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on
|
||||
[```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.
|
||||
|
||||
Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). In the future, we plan to
|
||||
extend the compatibility to other compilers, platforms and operating systems.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well.
|
||||
|
||||
It will compile the shared library in optimized mode, and ```make install``` install it to the desired prefix folder (by default ```/usr/local/lib```).
|
||||
|
||||
If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ../src -CMAKE_INSTALL_PREFIX=/usr/
|
||||
make
|
||||
```
|
||||
The available options are the following:
|
||||
|
||||
| Variable Name | values |
|
||||
| -------------------- | ------------------------------------------------------------ |
|
||||
| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) |
|
||||
| WARNING_PARANOID | All warnings are shown and treated as errors. Off by default |
|
||||
| ENABLE_TESTING | Compiles unit tests and integration tests |
|
||||
|
||||
------
|
||||
|
||||
<img src="docs/logo-sandboxaq-black.svg">
|
||||
|
||||
<img src="docs/logo-inpher1.png">
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 550 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB |
@@ -1,139 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
|
||||
<svg
|
||||
version="1.1"
|
||||
id="Layer_1"
|
||||
x="0px"
|
||||
y="0px"
|
||||
viewBox="0 0 270 49.4"
|
||||
style="enable-background:new 0 0 270 49.4;"
|
||||
xml:space="preserve"
|
||||
sodipodi:docname="logo-sandboxaq-black.svg"
|
||||
inkscape:version="1.3.2 (1:1.3.2+202311252150+091e20ef0f)"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||
id="defs9839">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</defs><sodipodi:namedview
|
||||
id="namedview9837"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#000000"
|
||||
borderopacity="0.25"
|
||||
inkscape:showpageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:deskcolor="#d1d1d1"
|
||||
showgrid="false"
|
||||
inkscape:zoom="1.194332"
|
||||
inkscape:cx="135.64068"
|
||||
inkscape:cy="25.118645"
|
||||
inkscape:window-width="804"
|
||||
inkscape:window-height="436"
|
||||
inkscape:window-x="190"
|
||||
inkscape:window-y="27"
|
||||
inkscape:window-maximized="0"
|
||||
inkscape:current-layer="Layer_1" />
|
||||
<style
|
||||
type="text/css"
|
||||
id="style9786">
|
||||
.st0{fill:#EBB028;}
|
||||
.st1{fill:#FFFFFF;}
|
||||
</style>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||
id="text9788">SANDBOX </text>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||
id="text9790">AQ</text>
|
||||
<g
|
||||
id="g9808">
|
||||
<g
|
||||
id="g9800">
|
||||
<g
|
||||
id="g9798">
|
||||
<path
|
||||
class="st0"
|
||||
d="m 8.9,9.7 v 3.9 l 29.6,17.1 v 2.7 c 0,1.2 -0.6,2.3 -1.6,2.9 L 31,39.8 v -4 L 1.4,18.6 V 15.9 C 1.4,14.7 2,13.6 3.1,13 Z"
|
||||
id="path9792" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M 18.3,45.1 3.1,36.3 C 2.1,35.7 1.4,34.6 1.4,33.4 V 26 L 28,41.4 21.5,45.1 c -0.9,0.6 -2.2,0.6 -3.2,0 z"
|
||||
id="path9794" />
|
||||
<path
|
||||
class="st0"
|
||||
d="m 21.6,4.3 15.2,8.8 c 1,0.6 1.7,1.7 1.7,2.9 v 7.5 L 11.8,8 18.3,4.3 c 1,-0.6 2.3,-0.6 3.3,0 z"
|
||||
id="path9796" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9806">
|
||||
<polygon
|
||||
class="st0"
|
||||
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||
id="polygon9802" />
|
||||
<path
|
||||
class="st0"
|
||||
d="m 246.9,31 -0.1,-0.1 h -0.1 c -0.2,0 -0.4,0 -0.6,0 -3.5,0 -5.7,-2.6 -5.7,-6.7 0,-4.1 2.2,-6.7 5.7,-6.7 3.5,0 5.7,2.6 5.7,6.7 0,0.3 0,0.6 0,0.9 l 3.6,4.2 c 0.7,-1.5 1,-3.2 1,-5.1 0,-6.5 -4.2,-11 -10.3,-11 -6.1,0 -10.3,4.5 -10.3,11 0,6.5 4.2,11 10.3,11 1.2,0 2.3,-0.2 3.4,-0.5 l 0.5,-0.2 z"
|
||||
id="path9804" />
|
||||
</g>
|
||||
</g><g
|
||||
id="g9824"
|
||||
style="fill:#1a1a1a">
|
||||
<path
|
||||
class="st1"
|
||||
d="m 58.7,13.2 c 4.6,0 7.4,2.5 7.4,6.5 h -4.6 c 0,-1.5 -1.1,-2.4 -2.9,-2.4 -1.9,0 -3.1,0.9 -3.1,2.3 0,1.3 0.7,1.9 2.2,2.2 l 3.2,0.7 c 3.8,0.8 5.6,2.6 5.6,5.9 0,4.1 -3.2,6.8 -8.1,6.8 -4.7,0 -7.8,-2.6 -7.8,-6.5 h 4.6 c 0,1.6 1.1,2.4 3.2,2.4 2.1,0 3.4,-0.8 3.4,-2.2 0,-1.2 -0.5,-1.8 -2,-2.1 l -3.2,-0.7 c -3.8,-0.8 -5.7,-2.9 -5.7,-6.4 0,-3.7 3.2,-6.5 7.8,-6.5 z"
|
||||
id="path9810"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 70.4,34.9 78,13.6 h 4.5 l 7.6,21.3 h -4.9 l -1.5,-4.5 h -6.9 l -1.5,4.5 z m 7.7,-8.4 h 4.2 L 80.8,22 c -0.2,-0.7 -0.5,-1.6 -0.6,-2.1 -0.1,0.5 -0.3,1.3 -0.6,2.1 z"
|
||||
id="path9812"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 95.3,34.9 V 13.6 h 4.6 l 9,13.5 V 13.6 h 4.6 v 21.3 h -4.6 l -9,-13.5 v 13.5 z"
|
||||
id="path9814"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 120.7,34.9 V 13.6 h 8 c 6.2,0 10.6,4.4 10.6,10.7 0,6.2 -4.2,10.6 -10.3,10.6 z m 4.7,-17 v 12.6 h 3.2 c 3.7,0 5.8,-2.3 5.8,-6.3 0,-4 -2.3,-6.4 -6.1,-6.4 h -2.9 z"
|
||||
id="path9816"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 145.4,13.6 h 8.8 c 4.3,0 6.9,2.2 6.9,5.9 0,2.3 -1,3.9 -3,4.8 2.1,0.7 3.2,2.3 3.2,4.7 0,3.8 -2.5,5.9 -7.1,5.9 h -8.8 z m 4.7,4.1 v 4.6 h 3.7 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.5 -0.9,-2.3 -2.6,-2.3 h -3.7 z m 0,8.5 v 4.6 h 3.9 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.4 -0.9,-2.2 -2.6,-2.2 z"
|
||||
id="path9818"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 176.5,35.2 c -6.1,0 -10.4,-4.5 -10.4,-11 0,-6.5 4.3,-11 10.4,-11 6.2,0 10.4,4.5 10.4,11 0,6.5 -4.2,11 -10.4,11 z m 0.1,-17.5 c -3.4,0 -5.5,2.4 -5.5,6.5 0,4.1 2.1,6.5 5.5,6.5 3.4,0 5.5,-2.5 5.5,-6.5 0,-4 -2.1,-6.5 -5.5,-6.5 z"
|
||||
id="path9820"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 190.4,13.6 h 5.5 l 1.8,2.8 c 0.8,1.2 1.5,2.5 2.5,4.3 l 4.3,-7 h 5.4 l -6.7,10.6 6.7,10.6 h -5.5 L 203,32.7 c -1.1,-1.7 -1.8,-3 -2.8,-4.9 l -4.6,7.1 h -5.5 l 7.1,-10.6 z"
|
||||
id="path9822"
|
||||
style="fill:#1a1a1a" />
|
||||
</g><path
|
||||
class="st0"
|
||||
d="m 229,34.9 h 4.7 L 226,13.6 h -4.3 L 214,34.8 h 4.6 l 1.6,-4.5 h 7.1 z m -5.1,-14.6 c 0,0 0,0 0,0 0,-0.1 0,-0.1 0,0 l 2.2,6.2 h -4.4 z"
|
||||
id="path9826" /><g
|
||||
id="g9832">
|
||||
<path
|
||||
class="st1"
|
||||
d="m 259.5,11.2 h 3.9 v 1 h -1.3 v 3.1 h -1.3 v -3.1 h -1.3 z m 4.5,0 h 1.7 l 0.6,2.5 0.6,-2.5 h 1.7 v 4.1 h -1 v -3.1 l -0.8,3.1 h -0.9 l -0.8,-3.1 v 3.1 h -1 v -4.1 z"
|
||||
id="path9830" />
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 5.0 KiB |
@@ -1,133 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
|
||||
<svg
|
||||
version="1.1"
|
||||
id="Layer_1"
|
||||
x="0px"
|
||||
y="0px"
|
||||
viewBox="0 0 270 49.4"
|
||||
style="enable-background:new 0 0 270 49.4;"
|
||||
xml:space="preserve"
|
||||
sodipodi:docname="logo-sandboxaq-white.svg"
|
||||
inkscape:version="1.2.2 (1:1.2.2+202212051551+b0a8486541)"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||
id="defs9839" /><sodipodi:namedview
|
||||
id="namedview9837"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#000000"
|
||||
borderopacity="0.25"
|
||||
inkscape:showpageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:deskcolor="#d1d1d1"
|
||||
showgrid="false"
|
||||
inkscape:zoom="2.3886639"
|
||||
inkscape:cx="135.22204"
|
||||
inkscape:cy="25.327967"
|
||||
inkscape:window-width="1072"
|
||||
inkscape:window-height="688"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="0"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="Layer_1" />
|
||||
<style
|
||||
type="text/css"
|
||||
id="style9786">
|
||||
.st0{fill:#EBB028;}
|
||||
.st1{fill:#FFFFFF;}
|
||||
</style>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||
id="text9788">SANDBOX </text>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||
id="text9790">AQ</text>
|
||||
<g
|
||||
id="g9834">
|
||||
<g
|
||||
id="g9828">
|
||||
<g
|
||||
id="g9808">
|
||||
<g
|
||||
id="g9800">
|
||||
<g
|
||||
id="g9798">
|
||||
<path
|
||||
class="st0"
|
||||
d="M8.9,9.7v3.9l29.6,17.1v2.7c0,1.2-0.6,2.3-1.6,2.9L31,39.8v-4L1.4,18.6v-2.7c0-1.2,0.6-2.3,1.7-2.9 L8.9,9.7z"
|
||||
id="path9792" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M18.3,45.1L3.1,36.3c-1-0.6-1.7-1.7-1.7-2.9V26L28,41.4l-6.5,3.7C20.6,45.7,19.3,45.7,18.3,45.1z"
|
||||
id="path9794" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M21.6,4.3l15.2,8.8c1,0.6,1.7,1.7,1.7,2.9v7.5L11.8,8l6.5-3.7C19.3,3.7,20.6,3.7,21.6,4.3z"
|
||||
id="path9796" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9806">
|
||||
<polygon
|
||||
class="st0"
|
||||
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||
id="polygon9802" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M246.9,31l-0.1-0.1l-0.1,0c-0.2,0-0.4,0-0.6,0c-3.5,0-5.7-2.6-5.7-6.7c0-4.1,2.2-6.7,5.7-6.7 s5.7,2.6,5.7,6.7c0,0.3,0,0.6,0,0.9l3.6,4.2c0.7-1.5,1-3.2,1-5.1c0-6.5-4.2-11-10.3-11c-6.1,0-10.3,4.5-10.3,11s4.2,11,10.3,11 c1.2,0,2.3-0.2,3.4-0.5l0.5-0.2L246.9,31z"
|
||||
id="path9804" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9824">
|
||||
<path
|
||||
class="st1"
|
||||
d="M58.7,13.2c4.6,0,7.4,2.5,7.4,6.5h-4.6c0-1.5-1.1-2.4-2.9-2.4c-1.9,0-3.1,0.9-3.1,2.3c0,1.3,0.7,1.9,2.2,2.2 l3.2,0.7c3.8,0.8,5.6,2.6,5.6,5.9c0,4.1-3.2,6.8-8.1,6.8c-4.7,0-7.8-2.6-7.8-6.5h4.6c0,1.6,1.1,2.4,3.2,2.4 c2.1,0,3.4-0.8,3.4-2.2c0-1.2-0.5-1.8-2-2.1l-3.2-0.7c-3.8-0.8-5.7-2.9-5.7-6.4C50.9,16,54.1,13.2,58.7,13.2z"
|
||||
id="path9810" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M70.4,34.9L78,13.6h4.5l7.6,21.3h-4.9l-1.5-4.5h-6.9l-1.5,4.5H70.4z M78.1,26.5h4.2L80.8,22 c-0.2-0.7-0.5-1.6-0.6-2.1c-0.1,0.5-0.3,1.3-0.6,2.1L78.1,26.5z"
|
||||
id="path9812" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M95.3,34.9V13.6h4.6l9,13.5V13.6h4.6v21.3h-4.6l-9-13.5v13.5H95.3z"
|
||||
id="path9814" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M120.7,34.9V13.6h8c6.2,0,10.6,4.4,10.6,10.7c0,6.2-4.2,10.6-10.3,10.6H120.7z M125.4,17.9v12.6h3.2 c3.7,0,5.8-2.3,5.8-6.3c0-4-2.3-6.4-6.1-6.4H125.4z"
|
||||
id="path9816" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M145.4,13.6h8.8c4.3,0,6.9,2.2,6.9,5.9c0,2.3-1,3.9-3,4.8c2.1,0.7,3.2,2.3,3.2,4.7c0,3.8-2.5,5.9-7.1,5.9 h-8.8V13.6z M150.1,17.7v4.6h3.7c1.7,0,2.6-0.8,2.6-2.4c0-1.5-0.9-2.3-2.6-2.3H150.1z M150.1,26.2v4.6h3.9c1.7,0,2.6-0.8,2.6-2.4 c0-1.4-0.9-2.2-2.6-2.2H150.1z"
|
||||
id="path9818" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M176.5,35.2c-6.1,0-10.4-4.5-10.4-11s4.3-11,10.4-11c6.2,0,10.4,4.5,10.4,11S182.7,35.2,176.5,35.2z M176.6,17.7c-3.4,0-5.5,2.4-5.5,6.5c0,4.1,2.1,6.5,5.5,6.5c3.4,0,5.5-2.5,5.5-6.5C182.1,20.2,180,17.7,176.6,17.7z"
|
||||
id="path9820" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M190.4,13.6h5.5l1.8,2.8c0.8,1.2,1.5,2.5,2.5,4.3l4.3-7h5.4l-6.7,10.6l6.7,10.6h-5.5l-1.4-2.2 c-1.1-1.7-1.8-3-2.8-4.9l-4.6,7.1h-5.5l7.1-10.6L190.4,13.6z"
|
||||
id="path9822" />
|
||||
</g>
|
||||
<path
|
||||
class="st0"
|
||||
d="M229,34.9h4.7L226,13.6h-4.3l-7.7,21.2h4.6l1.6-4.5h7.1L229,34.9z M223.9,20.3 C223.9,20.3,223.9,20.3,223.9,20.3C223.9,20.2,223.9,20.2,223.9,20.3l2.2,6.2h-4.4L223.9,20.3z"
|
||||
id="path9826" />
|
||||
</g>
|
||||
<g
|
||||
id="g9832">
|
||||
<path
|
||||
class="st1"
|
||||
d="M259.5,11.2h3.9v1h-1.3v3.1h-1.3v-3.1h-1.3V11.2L259.5,11.2z M264,11.2h1.7l0.6,2.5l0.6-2.5h1.7v4.1h-1v-3.1 l-0.8,3.1h-0.9l-0.8-3.1v3.1h-1V11.2L264,11.2z"
|
||||
id="path9830" />
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 4.7 KiB |
@@ -1,2 +0,0 @@
|
||||
library: spqlios-fft
|
||||
version: 2.0.0
|
||||
@@ -1,27 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
# this script generates one tag if there is a version change in manifest.yaml
|
||||
cd `dirname $0`/..
|
||||
if [ "v$1" = "v-y" ]; then
|
||||
echo "production mode!";
|
||||
fi
|
||||
changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
|
||||
oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
|
||||
version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
|
||||
echo "Versions: $oldversion --> $version"
|
||||
if [ "v$oldversion" = "v$version" ]; then
|
||||
echo "Same version - nothing to do"; exit 0;
|
||||
fi
|
||||
if [ "v$1" = "v-y" ]; then
|
||||
git config user.name github-actions
|
||||
git config user.email github-actions@github.com
|
||||
git tag -a "v$version" -m "Version $version"
|
||||
git push origin "v$version"
|
||||
else
|
||||
cat <<EOF
|
||||
# the script would do:
|
||||
git tag -a "v$version" -m "Version $version"
|
||||
git push origin "v$version"
|
||||
EOF
|
||||
fi
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
# ONLY USE A PREFIX YOU ARE CONFIDENT YOU CAN WIPE OUT ENTIRELY
|
||||
CI_INSTALL_PREFIX=/opt/spqlios
|
||||
CI_REPO_URL=https://spq-dav.algonics.net/ci
|
||||
WORKDIR=`pwd`
|
||||
if [ "x$DESTDIR" = "x" ]; then
|
||||
DESTDIR=/
|
||||
else
|
||||
mkdir -p $DESTDIR
|
||||
DESTDIR=`realpath $DESTDIR`
|
||||
fi
|
||||
DIR=`dirname "$0"`
|
||||
cd $DIR/..
|
||||
DIR=`pwd`
|
||||
|
||||
FULL_UNAME=`uname -a | tr '[A-Z]' '[a-z]'`
|
||||
HOST=`echo $FULL_UNAME | sed 's/ .*//'`
|
||||
ARCH=none
|
||||
case "$HOST" in
|
||||
*linux*)
|
||||
DISTRIB=`lsb_release -c | awk '{print $2}' | tr '[A-Z]' '[a-z]'`
|
||||
HOST=linux-$DISTRIB
|
||||
;;
|
||||
*darwin*)
|
||||
HOST=darwin
|
||||
;;
|
||||
*mingw*|*msys*)
|
||||
DISTRIB=`echo $MSYSTEM | tr '[A-Z]' '[a-z]'`
|
||||
HOST=msys64-$DISTRIB
|
||||
;;
|
||||
*)
|
||||
echo "Host unknown: $HOST";
|
||||
exit 1
|
||||
esac
|
||||
case "$FULL_UNAME" in
|
||||
*x86_64*)
|
||||
ARCH=x86_64
|
||||
;;
|
||||
*aarch64*)
|
||||
ARCH=aarch64
|
||||
;;
|
||||
*arm64*)
|
||||
ARCH=arm64
|
||||
;;
|
||||
*)
|
||||
echo "Architecture unknown: $FULL_UNAME";
|
||||
exit 1
|
||||
esac
|
||||
UNAME="$HOST-$ARCH"
|
||||
CMH=
|
||||
if [ -d lib/spqlios/.git ]; then
|
||||
CMH=`git submodule status | sed 's/\(..........\).*/\1/'`
|
||||
else
|
||||
CMH=`git rev-parse HEAD | sed 's/\(..........\).*/\1/'`
|
||||
fi
|
||||
FNAME=spqlios-arithmetic-$CMH-$UNAME.tar.gz
|
||||
|
||||
cat <<EOF
|
||||
================= CI MINI-PACKAGER ==================
|
||||
Work Dir: WORKDIR=$WORKDIR
|
||||
Spq Dir: DIR=$DIR
|
||||
Install Root: DESTDIR=$DESTDIR
|
||||
Install Prefix: CI_INSTALL_PREFIX=$CI_INSTALL_PREFIX
|
||||
Archive Name: FNAME=$FNAME
|
||||
CI WebDav: CI_REPO_URL=$CI_REPO_URL
|
||||
=====================================================
|
||||
EOF
|
||||
|
||||
if [ "x$1" = "xcreate" ]; then
|
||||
rm -rf dist
|
||||
cmake -B build -S . -DCMAKE_INSTALL_PREFIX="$CI_INSTALL_PREFIX" -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON || exit 1
|
||||
cmake --build build || exit 1
|
||||
rm -rf "$DIR/dist" 2>/dev/null
|
||||
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||
DESTDIR="$DIR/dist" cmake --install build || exit 1
|
||||
if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
|
||||
tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
|
||||
else
|
||||
# fix since msys can mess up the paths
|
||||
REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
|
||||
echo "REAL_DEST: $REAL_DEST"
|
||||
[ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
|
||||
fi
|
||||
[ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
|
||||
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
|
||||
curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||
fi
|
||||
|
||||
if [ "x$1" = "xinstall" ]; then
|
||||
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
|
||||
# cleaning
|
||||
rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
|
||||
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||
# downloading
|
||||
curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||
[ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
|
||||
# installing
|
||||
mkdir -p $DESTDIR
|
||||
tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
|
||||
exit 0
|
||||
fi
|
||||
@@ -1,181 +0,0 @@
|
||||
#!/usr/bin/perl
|
||||
##
|
||||
## This script will help update manifest.yaml and Changelog.md before a release
|
||||
## Any merge to master that changes the version line in manifest.yaml
|
||||
## is considered as a new release.
|
||||
##
|
||||
## When ready to make a release, please run ./scripts/prepare-release
|
||||
## and commit push the final result!
|
||||
use File::Basename;
|
||||
use Cwd 'abs_path';
|
||||
|
||||
# find its way to the root of git's repository
|
||||
my $scriptsdirname = dirname(abs_path(__FILE__));
|
||||
chdir "$scriptsdirname/..";
|
||||
print "✓ Entering directory:".`pwd`;
|
||||
|
||||
# ensures that the current branch is ahead of origin/main
|
||||
my $diff= `git diff`;
|
||||
chop $diff;
|
||||
if ($diff =~ /./) {
|
||||
die("ERROR: Please commit all the changes before calling the prepare-release script.");
|
||||
} else {
|
||||
print("✓ All changes are comitted.\n");
|
||||
}
|
||||
system("git fetch origin");
|
||||
my $vcount = `git rev-list --left-right --count origin/main...HEAD`;
|
||||
$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/;
|
||||
if ($2>0) {
|
||||
die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main.");
|
||||
} else {
|
||||
print("✓ Current HEAD is up to date with origin/main.\n");
|
||||
}
|
||||
|
||||
mkdir ".changes";
|
||||
my $currentbranch = `git rev-parse --abbrev-ref HEAD`;
|
||||
chop $currentbranch;
|
||||
$currentbranch =~ s/[^a-zA-Z._-]+/-/g;
|
||||
my $changefile=".changes/$currentbranch.md";
|
||||
my $origmanifestfile=".changes/$currentbranch--manifest.yaml";
|
||||
my $origchangelogfile=".changes/$currentbranch--Changelog.md";
|
||||
|
||||
my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml");
|
||||
if ($exit_code!=0 or ! -f $origmanifestfile) {
|
||||
die("ERROR: failed to download manifest.yaml");
|
||||
}
|
||||
$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md");
|
||||
if ($exit_code!=0 or ! -f $origchangelogfile) {
|
||||
die("ERROR: failed to download Changelog.md");
|
||||
}
|
||||
|
||||
# read the current version (from origin/main manifest)
|
||||
my $vmajor = 0;
|
||||
my $vminor = 0;
|
||||
my $vpatch = 0;
|
||||
my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`;
|
||||
chop $versionline;
|
||||
if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) {
|
||||
die("ERROR: invalid version in manifest file: $versionline\n");
|
||||
} else {
|
||||
$vmajor = int($1);
|
||||
$vminor = int($2);
|
||||
$vpatch = int($3);
|
||||
}
|
||||
print "Version in manifest file: $vmajor.$vminor.$vpatch\n";
|
||||
|
||||
if (not -f $changefile) {
|
||||
## create a changes file
|
||||
open F,">$changefile";
|
||||
print F "# Changefile for branch $currentbranch\n\n";
|
||||
print F "## Type of release (major,minor,patch)?\n\n";
|
||||
print F "releasetype: patch\n\n";
|
||||
print F "## What has changed (please edit)?\n\n";
|
||||
print F "- This has changed.\n";
|
||||
close F;
|
||||
}
|
||||
|
||||
system("editor $changefile");
|
||||
|
||||
# compute the new version
|
||||
my $nvmajor;
|
||||
my $nvminor;
|
||||
my $nvpatch;
|
||||
my $changelog;
|
||||
my $recordchangelog=0;
|
||||
open F,"$changefile";
|
||||
while ($line=<F>) {
|
||||
chop $line;
|
||||
if ($recordchangelog) {
|
||||
($line =~ /^$/) and next;
|
||||
$changelog .= "$line\n";
|
||||
next;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *patch *$/) {
|
||||
$nvmajor=$vmajor;
|
||||
$nvminor=$vminor;
|
||||
$nvpatch=$vpatch+1;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *minor *$/) {
|
||||
$nvmajor=$vmajor;
|
||||
$nvminor=$vminor+1;
|
||||
$nvpatch=0;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *major *$/) {
|
||||
$nvmajor=$vmajor+1;
|
||||
$nvminor=0;
|
||||
$nvpatch=0;
|
||||
}
|
||||
if ($line =~ /^## What has changed/) {
|
||||
$recordchangelog=1;
|
||||
}
|
||||
}
|
||||
close F;
|
||||
print "New version: $nvmajor.$nvminor.$nvpatch\n";
|
||||
print "Changes:\n$changelog";
|
||||
|
||||
# updating manifest.yaml
|
||||
open F,"manifest.yaml";
|
||||
open G,">.changes/manifest.yaml";
|
||||
while ($line=<F>) {
|
||||
if ($line =~ /^version *: */) {
|
||||
print G "version: $nvmajor.$nvminor.$nvpatch\n";
|
||||
next;
|
||||
}
|
||||
print G $line;
|
||||
}
|
||||
close F;
|
||||
close G;
|
||||
# updating Changelog.md
|
||||
open F,"$origchangelogfile";
|
||||
open G,">.changes/Changelog.md";
|
||||
print G <<EOF
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
EOF
|
||||
;
|
||||
print G "## [$nvmajor.$nvminor.$nvpatch] - ".`date '+%Y-%m-%d'`."\n";
|
||||
print G "$changelog\n";
|
||||
my $skip_section=1;
|
||||
while ($line=<F>) {
|
||||
if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) {
|
||||
if ($1>$nvmajor) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($1<$nvmajor) {
|
||||
$skip_section=0;
|
||||
} elsif ($2>$nvminor) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($2<$nvminor) {
|
||||
$skip_section=0;
|
||||
} elsif ($3>$nvpatch) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($2<$nvpatch) {
|
||||
$skip_section=0;
|
||||
} else {
|
||||
$skip_section=1;
|
||||
}
|
||||
}
|
||||
($skip_section) and next;
|
||||
print G $line;
|
||||
}
|
||||
close F;
|
||||
close G;
|
||||
|
||||
print "-------------------------------------\n";
|
||||
print "THIS WILL BE UPDATED:\n";
|
||||
print "-------------------------------------\n";
|
||||
system("diff -u manifest.yaml .changes/manifest.yaml");
|
||||
system("diff -u Changelog.md .changes/Changelog.md");
|
||||
print "-------------------------------------\n";
|
||||
print "To proceed: press <enter> otherwise <CTRL+C>\n";
|
||||
my $bla;
|
||||
$bla=<STDIN>;
|
||||
system("cp -vf .changes/manifest.yaml manifest.yaml");
|
||||
system("cp -vf .changes/Changelog.md Changelog.md");
|
||||
system("git commit -a -m \"Update version and changelog.\"");
|
||||
system("git push");
|
||||
print("✓ Changes have been committed and pushed!\n");
|
||||
print("✓ A new release will be created when this branch is merged to main.\n");
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
enable_language(ASM)
|
||||
|
||||
# C source files that are compiled for all targets (i.e. reference code)
|
||||
set(SRCS_GENERIC
|
||||
commons.c
|
||||
commons_private.c
|
||||
coeffs/coeffs_arithmetic.c
|
||||
arithmetic/vec_znx.c
|
||||
arithmetic/vec_znx_dft.c
|
||||
arithmetic/vector_matrix_product.c
|
||||
cplx/cplx_common.c
|
||||
cplx/cplx_conversions.c
|
||||
cplx/cplx_fft_asserts.c
|
||||
cplx/cplx_fft_ref.c
|
||||
cplx/cplx_fftvec_ref.c
|
||||
cplx/cplx_ifft_ref.c
|
||||
cplx/spqlios_cplx_fft.c
|
||||
reim4/reim4_arithmetic_ref.c
|
||||
reim4/reim4_fftvec_addmul_ref.c
|
||||
reim4/reim4_fftvec_conv_ref.c
|
||||
reim/reim_conversions.c
|
||||
reim/reim_fft_ifft.c
|
||||
reim/reim_fft_ref.c
|
||||
reim/reim_fftvec_ref.c
|
||||
reim/reim_ifft_ref.c
|
||||
reim/reim_ifft_ref.c
|
||||
reim/reim_to_tnx_ref.c
|
||||
q120/q120_ntt.c
|
||||
q120/q120_arithmetic_ref.c
|
||||
q120/q120_arithmetic_simple.c
|
||||
arithmetic/scalar_vector_product.c
|
||||
arithmetic/vec_znx_big.c
|
||||
arithmetic/znx_small.c
|
||||
arithmetic/module_api.c
|
||||
arithmetic/zn_vmp_int8_ref.c
|
||||
arithmetic/zn_vmp_int16_ref.c
|
||||
arithmetic/zn_vmp_int32_ref.c
|
||||
arithmetic/zn_vmp_ref.c
|
||||
arithmetic/zn_api.c
|
||||
arithmetic/zn_conversions_ref.c
|
||||
arithmetic/zn_approxdecomp_ref.c
|
||||
arithmetic/vec_rnx_api.c
|
||||
arithmetic/vec_rnx_conversions_ref.c
|
||||
arithmetic/vec_rnx_svp_ref.c
|
||||
reim/reim_execute.c
|
||||
cplx/cplx_execute.c
|
||||
reim4/reim4_execute.c
|
||||
arithmetic/vec_rnx_arithmetic.c
|
||||
arithmetic/vec_rnx_approxdecomp_ref.c
|
||||
arithmetic/vec_rnx_vmp_ref.c
|
||||
)
|
||||
# C or assembly source files compiled only on x86 targets
|
||||
set(SRCS_X86
|
||||
)
|
||||
# C or assembly source files compiled only on aarch64 targets
|
||||
set(SRCS_AARCH64
|
||||
cplx/cplx_fallbacks_aarch64.c
|
||||
reim/reim_fallbacks_aarch64.c
|
||||
reim4/reim4_fallbacks_aarch64.c
|
||||
q120/q120_fallbacks_aarch64.c
|
||||
reim/reim_fft_neon.c
|
||||
)
|
||||
|
||||
# C or assembly source files compiled only on x86: avx, avx2, fma targets
|
||||
set(SRCS_FMA_C
|
||||
arithmetic/vector_matrix_product_avx.c
|
||||
cplx/cplx_conversions_avx2_fma.c
|
||||
cplx/cplx_fft_avx2_fma.c
|
||||
cplx/cplx_fft_sse.c
|
||||
cplx/cplx_fftvec_avx2_fma.c
|
||||
cplx/cplx_ifft_avx2_fma.c
|
||||
reim4/reim4_arithmetic_avx2.c
|
||||
reim4/reim4_fftvec_conv_fma.c
|
||||
reim4/reim4_fftvec_addmul_fma.c
|
||||
reim/reim_conversions_avx.c
|
||||
reim/reim_fft4_avx_fma.c
|
||||
reim/reim_fft8_avx_fma.c
|
||||
reim/reim_ifft4_avx_fma.c
|
||||
reim/reim_ifft8_avx_fma.c
|
||||
reim/reim_fft_avx2.c
|
||||
reim/reim_ifft_avx2.c
|
||||
reim/reim_to_tnx_avx.c
|
||||
reim/reim_fftvec_fma.c
|
||||
)
|
||||
set(SRCS_FMA_ASM
|
||||
cplx/cplx_fft16_avx_fma.s
|
||||
cplx/cplx_ifft16_avx_fma.s
|
||||
reim/reim_fft16_avx_fma.s
|
||||
reim/reim_ifft16_avx_fma.s
|
||||
)
|
||||
set(SRCS_FMA_WIN32_ASM
|
||||
cplx/cplx_fft16_avx_fma_win32.s
|
||||
cplx/cplx_ifft16_avx_fma_win32.s
|
||||
reim/reim_fft16_avx_fma_win32.s
|
||||
reim/reim_ifft16_avx_fma_win32.s
|
||||
)
|
||||
set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||
set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||
|
||||
# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets
|
||||
set(SRCS_AVX512
|
||||
cplx/cplx_fft_avx512.c
|
||||
)
|
||||
set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq")
|
||||
|
||||
# C or assembly source files compiled only on x86: avx2 + bmi targets
|
||||
set(SRCS_AVX2
|
||||
arithmetic/vec_znx_avx.c
|
||||
coeffs/coeffs_arithmetic_avx.c
|
||||
arithmetic/vec_znx_dft_avx2.c
|
||||
arithmetic/zn_vmp_int8_avx.c
|
||||
arithmetic/zn_vmp_int16_avx.c
|
||||
arithmetic/zn_vmp_int32_avx.c
|
||||
q120/q120_arithmetic_avx2.c
|
||||
q120/q120_ntt_avx2.c
|
||||
arithmetic/vec_rnx_arithmetic_avx.c
|
||||
arithmetic/vec_rnx_approxdecomp_avx.c
|
||||
arithmetic/vec_rnx_vmp_avx.c
|
||||
|
||||
)
|
||||
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")
|
||||
|
||||
# C source files on float128 via libquadmath on x86 targets targets
|
||||
set(SRCS_F128
|
||||
cplx_f128/cplx_fft_f128.c
|
||||
cplx_f128/cplx_fft_f128.h
|
||||
)
|
||||
|
||||
# H header files containing the public API (these headers are installed)
|
||||
set(HEADERSPUBLIC
|
||||
commons.h
|
||||
arithmetic/vec_znx_arithmetic.h
|
||||
arithmetic/vec_rnx_arithmetic.h
|
||||
arithmetic/zn_arithmetic.h
|
||||
cplx/cplx_fft.h
|
||||
reim/reim_fft.h
|
||||
q120/q120_common.h
|
||||
q120/q120_arithmetic.h
|
||||
q120/q120_ntt.h
|
||||
)
|
||||
|
||||
# H header files containing the private API (these headers are used internally)
|
||||
set(HEADERSPRIVATE
|
||||
commons_private.h
|
||||
cplx/cplx_fft_internal.h
|
||||
cplx/cplx_fft_private.h
|
||||
reim4/reim4_arithmetic.h
|
||||
reim4/reim4_fftvec_internal.h
|
||||
reim4/reim4_fftvec_private.h
|
||||
reim4/reim4_fftvec_public.h
|
||||
reim/reim_fft_internal.h
|
||||
reim/reim_fft_private.h
|
||||
q120/q120_arithmetic_private.h
|
||||
q120/q120_ntt_private.h
|
||||
arithmetic/vec_znx_arithmetic.h
|
||||
arithmetic/vec_rnx_arithmetic_private.h
|
||||
arithmetic/vec_rnx_arithmetic_plugin.h
|
||||
arithmetic/zn_arithmetic_private.h
|
||||
arithmetic/zn_arithmetic_plugin.h
|
||||
coeffs/coeffs_arithmetic.h
|
||||
reim/reim_fft_core_template.h
|
||||
)
|
||||
|
||||
set(SPQLIOSSOURCES
|
||||
${SRCS_GENERIC}
|
||||
${HEADERSPUBLIC}
|
||||
${HEADERSPRIVATE}
|
||||
)
|
||||
if (${X86})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
${SRCS_X86}
|
||||
${SRCS_FMA_C}
|
||||
${SRCS_FMA_ASM}
|
||||
${SRCS_AVX2}
|
||||
${SRCS_AVX512}
|
||||
)
|
||||
elseif (${X86_WIN32})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
#${SRCS_X86}
|
||||
${SRCS_FMA_C}
|
||||
${SRCS_FMA_WIN32_ASM}
|
||||
${SRCS_AVX2}
|
||||
${SRCS_AVX512}
|
||||
)
|
||||
elseif (${AARCH64})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
${SRCS_AARCH64}
|
||||
)
|
||||
endif ()
|
||||
|
||||
|
||||
set(SPQLIOSLIBDEP
|
||||
m # libmath depencency for cosinus/sinus functions
|
||||
)
|
||||
|
||||
if (ENABLE_SPQLIOS_F128)
|
||||
find_library(quadmath REQUIRED NAMES quadmath)
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128})
|
||||
set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath)
|
||||
endif (ENABLE_SPQLIOS_F128)
|
||||
|
||||
add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
|
||||
add_library(libspqlios SHARED ${SPQLIOSSOURCES})
|
||||
set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)
|
||||
set_property(TARGET libspqlios-static PROPERTY OUTPUT_NAME spqlios)
|
||||
set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR})
|
||||
set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION})
|
||||
if (NOT APPLE)
|
||||
target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined)
|
||||
target_link_options(libspqlios PUBLIC -Wl,--no-undefined)
|
||||
endif()
|
||||
target_link_libraries(libspqlios ${SPQLIOSLIBDEP})
|
||||
target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP})
|
||||
install(TARGETS libspqlios-static)
|
||||
install(TARGETS libspqlios)
|
||||
|
||||
# install the public headers only
|
||||
foreach (file ${HEADERSPUBLIC})
|
||||
get_filename_component(dir ${file} DIRECTORY)
|
||||
install(FILES ${file} DESTINATION include/spqlios/${dir})
|
||||
endforeach ()
|
||||
@@ -1,172 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
static void fill_generic_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
module->func.vec_znx_zero = vec_znx_zero_ref;
|
||||
module->func.vec_znx_copy = vec_znx_copy_ref;
|
||||
module->func.vec_znx_negate = vec_znx_negate_ref;
|
||||
module->func.vec_znx_add = vec_znx_add_ref;
|
||||
module->func.vec_znx_sub = vec_znx_sub_ref;
|
||||
module->func.vec_znx_rotate = vec_znx_rotate_ref;
|
||||
module->func.vec_znx_mul_xp_minus_one = vec_znx_mul_xp_minus_one_ref;
|
||||
module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
|
||||
module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
|
||||
module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
module->func.vec_znx_negate = vec_znx_negate_avx;
|
||||
module->func.vec_znx_add = vec_znx_add_avx;
|
||||
module->func.vec_znx_sub = vec_znx_sub_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_fft64_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
// module->func.vec_znx_dft = ...;
|
||||
module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
|
||||
module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes;
|
||||
module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k;
|
||||
module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||
module->func.vec_znx_dft = fft64_vec_znx_dft;
|
||||
module->func.vec_znx_idft = fft64_vec_znx_idft;
|
||||
module->func.vec_dft_add = fft64_vec_dft_add;
|
||||
module->func.vec_dft_sub = fft64_vec_dft_sub;
|
||||
module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes;
|
||||
module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a;
|
||||
module->func.vec_znx_big_add = fft64_vec_znx_big_add;
|
||||
module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small;
|
||||
module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2;
|
||||
module->func.vec_znx_big_sub = fft64_vec_znx_big_sub;
|
||||
module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a;
|
||||
module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b;
|
||||
module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2;
|
||||
module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate;
|
||||
module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism;
|
||||
module->func.svp_prepare = fft64_svp_prepare_ref;
|
||||
module->func.svp_apply_dft = fft64_svp_apply_dft_ref;
|
||||
module->func.svp_apply_dft_to_dft = fft64_svp_apply_dft_to_dft_ref;
|
||||
module->func.znx_small_single_product = fft64_znx_small_single_product;
|
||||
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
|
||||
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
|
||||
module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes;
|
||||
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
|
||||
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_ref;
|
||||
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
|
||||
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
|
||||
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_ref;
|
||||
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
|
||||
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||
module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big;
|
||||
module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol;
|
||||
module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
// TODO: enable when avx implementation is done
|
||||
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
|
||||
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
|
||||
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_avx;
|
||||
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
|
||||
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_ntt120_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
// module->func.vec_znx_dft = ...;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
module->func.vec_znx_dft = ntt120_vec_znx_dft_avx;
|
||||
module->func.vec_znx_idft = ntt120_vec_znx_idft_avx;
|
||||
module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx;
|
||||
module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_virtual_table(MODULE* module) {
|
||||
fill_generic_virtual_table(module);
|
||||
switch (module->module_type) {
|
||||
case FFT64:
|
||||
fill_fft64_virtual_table(module);
|
||||
break;
|
||||
case NTT120:
|
||||
fill_ntt120_virtual_table(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // invalid type
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_fft64_precomp(MODULE* module) {
|
||||
// fill any necessary precomp stuff
|
||||
module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50);
|
||||
module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
|
||||
module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63);
|
||||
module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
|
||||
module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m);
|
||||
module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m);
|
||||
module->mod.fft64.add_fft = new_reim_fftvec_add_precomp(module->m);
|
||||
module->mod.fft64.sub_fft = new_reim_fftvec_sub_precomp(module->m);
|
||||
}
|
||||
static void fill_ntt120_precomp(MODULE* module) {
|
||||
// fill any necessary precomp stuff
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn);
|
||||
module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn);
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_module_precomp(MODULE* module) {
|
||||
switch (module->module_type) {
|
||||
case FFT64:
|
||||
fill_fft64_precomp(module);
|
||||
break;
|
||||
case NTT120:
|
||||
fill_ntt120_precomp(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // invalid type
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) {
|
||||
// init to zero to ensure that any non-initialized field bug is detected
|
||||
// by at least a "proper" segfault
|
||||
memset(module, 0, sizeof(MODULE));
|
||||
module->module_type = mtype;
|
||||
module->nn = nn;
|
||||
module->m = nn >> 1;
|
||||
fill_module_precomp(module);
|
||||
fill_virtual_table(module);
|
||||
}
|
||||
|
||||
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) {
|
||||
MODULE* m = (MODULE*)malloc(sizeof(MODULE));
|
||||
fill_module(m, N, mtype);
|
||||
return m;
|
||||
}
|
||||
|
||||
EXPORT void delete_module_info(MODULE* mod) {
|
||||
switch (mod->module_type) {
|
||||
case FFT64:
|
||||
free(mod->mod.fft64.p_conv);
|
||||
free(mod->mod.fft64.p_fft);
|
||||
free(mod->mod.fft64.p_ifft);
|
||||
free(mod->mod.fft64.p_reim_to_znx);
|
||||
free(mod->mod.fft64.mul_fft);
|
||||
free(mod->mod.fft64.p_addmul);
|
||||
break;
|
||||
case NTT120:
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt);
|
||||
q120_del_intt_bb_precomp(mod->mod.q120.p_intt);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
free(mod);
|
||||
}
|
||||
|
||||
EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; }
|
||||
@@ -1,102 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); }
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); }
|
||||
|
||||
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); }
|
||||
|
||||
EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||
|
||||
// public wrappers
|
||||
EXPORT void svp_prepare(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
) {
|
||||
module->func.svp_prepare(module, ppol, pol);
|
||||
}
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ppol, pol);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)ppol);
|
||||
}
|
||||
|
||||
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl) {
|
||||
module->func.svp_apply_dft(module, // N
|
||||
res,
|
||||
res_size, // output
|
||||
ppol, // prepared pol
|
||||
a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols) {
|
||||
module->func.svp_apply_dft_to_dft(module, // N
|
||||
res, res_size, res_cols, // output
|
||||
ppol, a, a_size, a_cols // prepared pol
|
||||
);
|
||||
}
|
||||
|
||||
// result = ppol * a
|
||||
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
double* const dres = (double*)res;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
double* const res_ptr = dres + i * nn;
|
||||
// copy the polynomial to res, apply fft in place, call fftvec_mul in place.
|
||||
reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr);
|
||||
reim_fft(module->mod.fft64.p_fft, res_ptr);
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
// result = ppol * a
|
||||
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size,
|
||||
uint64_t a_cols // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t res_sl = nn * res_cols;
|
||||
const uint64_t a_sl = nn * a_cols;
|
||||
double* const dres = (double*)res;
|
||||
double* const da = (double*)a;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const double* a_ptr = da + i * a_sl;
|
||||
double* const res_ptr = dres + i * res_sl;
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, a_ptr, dppol);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; i++) {
|
||||
memset(dres + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,344 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
|
||||
// Add here initialization of items that are in the precomp
|
||||
const uint64_t m = module->m;
|
||||
module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0);
|
||||
module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0);
|
||||
module->precomp.fft64.p_fftvec_add = new_reim_fftvec_add_precomp(m);
|
||||
module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m);
|
||||
module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m);
|
||||
}
|
||||
|
||||
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
|
||||
// Add here deleters for items that are in the precomp
|
||||
delete_reim_fft_precomp(module->precomp.fft64.p_fft);
|
||||
delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
|
||||
delete_reim_fftvec_add_precomp(module->precomp.fft64.p_fftvec_add);
|
||||
delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
|
||||
delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
|
||||
}
|
||||
|
||||
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
|
||||
// Add function pointers here
|
||||
module->vtable.vec_rnx_add = vec_rnx_add_ref;
|
||||
module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
|
||||
module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
|
||||
module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
|
||||
module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
|
||||
module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
|
||||
module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
|
||||
module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
|
||||
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
|
||||
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
|
||||
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref;
|
||||
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref;
|
||||
module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
|
||||
module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
|
||||
module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
|
||||
module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
|
||||
module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
|
||||
module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
|
||||
module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
|
||||
module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
|
||||
module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;
|
||||
|
||||
// Add optimized function pointers here
|
||||
if (CPU_SUPPORTS("avx")) {
|
||||
module->vtable.vec_rnx_add = vec_rnx_add_avx;
|
||||
module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
|
||||
module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
|
||||
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
|
||||
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
|
||||
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx;
|
||||
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx;
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
|
||||
}
|
||||
}
|
||||
|
||||
void init_rnx_module_info(MOD_RNX* module, //
|
||||
uint64_t n, RNX_MODULE_TYPE mtype) {
|
||||
memset(module, 0, sizeof(MOD_RNX));
|
||||
module->n = n;
|
||||
module->m = n >> 1;
|
||||
module->mtype = mtype;
|
||||
switch (mtype) {
|
||||
case FFT64:
|
||||
fft64_init_rnx_module_precomp(module);
|
||||
fft64_init_rnx_module_vtable(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
void finalize_rnx_module_info(MOD_RNX* module) {
|
||||
if (module->custom) module->custom_deleter(module->custom);
|
||||
switch (module->mtype) {
|
||||
case FFT64:
|
||||
fft64_finalize_rnx_module_precomp(module);
|
||||
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
|
||||
MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX));
|
||||
init_rnx_module_info(res, nn, mtype);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
|
||||
finalize_rnx_module_info(module_info);
|
||||
free(module_info);
|
||||
}
|
||||
|
||||
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; }
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||
|
||||
//////////////// wrappers //////////////////
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void rnx_vmp_prepare_row( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) {
|
||||
return module->vtable.rnx_vmp_prepare_tmp_bytes(module);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res size
|
||||
uint64_t a_size, // a size
|
||||
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||
) {
|
||||
return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }
|
||||
|
||||
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
) {
|
||||
module->vtable.rnx_svp_prepare(module, ppol, pol);
|
||||
}
|
||||
|
||||
EXPORT void rnx_svp_apply( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.rnx_svp_apply(module, // N
|
||||
res, res_size, res_sl, // output
|
||||
ppol, // prepared pol
|
||||
a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a) { // a
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "immintrin.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a);
|
||||
const uint64_t ell = gadget->ell;
|
||||
const __m256i k = _mm256_set1_epi64x(gadget->k);
|
||||
const __m256d add_cst = _mm256_set1_pd(gadget->add_cst);
|
||||
const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask);
|
||||
const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask);
|
||||
const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst);
|
||||
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||
// gadget decompose column by column
|
||||
if (msize == ell) {
|
||||
// this is the main scenario when msize == ell
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; j += 4) {
|
||||
double* rr = last_r + j;
|
||||
const double* aa = a + j;
|
||||
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||
__m256i t_int = _mm256_castpd_si256(t_dbl);
|
||||
do {
|
||||
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||
t_int = _mm256_srlv_epi64(t_int, k);
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
} else if (msize > 0) {
|
||||
// otherwise, if msize < ell: there is one additional rshift
|
||||
const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k);
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; j += 4) {
|
||||
double* rr = last_r + j;
|
||||
const double* aa = a + j;
|
||||
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||
__m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh);
|
||||
do {
|
||||
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||
t_int = _mm256_srlv_epi64(t_int, k);
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
}
|
||||
// zero-out the last slices (if any)
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
typedef union di {
|
||||
double dv;
|
||||
uint64_t uv;
|
||||
} di_t;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
) {
|
||||
if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision");
|
||||
TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET));
|
||||
res->k = k;
|
||||
res->ell = ell;
|
||||
// double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||
union di add_cst;
|
||||
add_cst.dv = UINT64_C(3) << (51 - ell * k);
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1);
|
||||
}
|
||||
res->add_cst = add_cst.dv;
|
||||
// uint64_t and_mask; // uint64(2^(K)-1)
|
||||
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||
// uint64_t or_mask; // double(2^52)
|
||||
union di or_mask;
|
||||
or_mask.dv = (UINT64_C(1) << 52);
|
||||
res->or_mask = or_mask.uv;
|
||||
// double sub_cst; // double(2^52 + 2^(K-1))
|
||||
res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1)));
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); }
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t k = gadget->k;
|
||||
const uint64_t ell = gadget->ell;
|
||||
const double add_cst = gadget->add_cst;
|
||||
const uint64_t and_mask = gadget->and_mask;
|
||||
const uint64_t or_mask = gadget->or_mask;
|
||||
const double sub_cst = gadget->sub_cst;
|
||||
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||
const uint64_t first_rsh = (ell - msize) * k;
|
||||
// gadget decompose column by column
|
||||
if (msize > 0) {
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; ++j) {
|
||||
double* rr = last_r + j;
|
||||
di_t t = {.dv = a[j] + add_cst};
|
||||
if (msize < ell) t.uv >>= first_rsh;
|
||||
do {
|
||||
di_t u;
|
||||
u.uv = (t.uv & and_mask) | or_mask;
|
||||
*rr = u.dv - sub_cst;
|
||||
t.uv >>= k;
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
}
|
||||
// zero-out the last slices (if any)
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] - b[i];
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_negate_ref(uint64_t nn, double* res, const double* a) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = -a[i];
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
for (uint64_t i = 0; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
rnx_negate_ref(nn, res_ptr, a_ptr);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_rotate_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_rotate_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_automorphism_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_automorphism_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a . (X^p - 1) */
|
||||
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_mul_xp_minus_one_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_mul_xp_minus_one_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,356 +0,0 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
/**
|
||||
* We support the following module families:
|
||||
* - FFT64:
|
||||
* the overall precision should fit at all times over 52 bits.
|
||||
*/
|
||||
typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
typedef struct rnx_module_info_t MOD_RNX;
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode);
|
||||
EXPORT void delete_rnx_module_info(MOD_RNX* module_info);
|
||||
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module);
|
||||
|
||||
// basic arithmetic
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a . (X^p - 1) */
|
||||
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// conversions //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
EXPORT void vec_rnx_to_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// isolated products (n.log(n), but not particularly optimized //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
/** @brief res = a * b : small polynomial product */
|
||||
EXPORT void rnx_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, // output
|
||||
const double* a, // a
|
||||
const double* b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b centermod 1: small polynomial product */
|
||||
EXPORT void tnxdbl_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
double* torus_res, // output
|
||||
const double* int_a, // a
|
||||
const double* torus_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b: small polynomial product */
|
||||
EXPORT void znx32_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* int_res, // output
|
||||
const int32_t* int_a, // a
|
||||
const int32_t* int_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b centermod 1: small polynomial product */
|
||||
EXPORT void tnx32_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* torus_res, // output
|
||||
const int32_t* int_a, // a
|
||||
const int32_t* torus_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared gadget decompositions (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
// decompose from tnx32
|
||||
|
||||
typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */
|
||||
EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
);
|
||||
EXPORT void delete_tnx32_approxdecomp_gadget(const MOD_RNX* module);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a // a
|
||||
);
|
||||
|
||||
// decompose from tnx32x2
|
||||
|
||||
typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */
|
||||
EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella,
|
||||
uint64_t kb, uint64_t ellb);
|
||||
EXPORT void delete_tnx32x2_approxdecomp_gadget(const MOD_RNX* module);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a // a
|
||||
);
|
||||
|
||||
// decompose from tnxdbl
|
||||
|
||||
typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
);
|
||||
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared scalar-vector product (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */
|
||||
typedef struct rnx_svp_ppol_t RNX_SVP_PPOL;
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */
|
||||
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief frees memory for a prepared vector */
|
||||
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void rnx_svp_apply( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared vector-matrix product (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT;
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void rnx_vmp_prepare_row( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res size
|
||||
uint64_t a_size, // a size
|
||||
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief sets res = DFT(a) */
|
||||
EXPORT void vec_rnx_dft(const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = iDFT(a_dft) -- idft is not normalized */
|
||||
EXPORT void vec_rnx_idft(const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
@@ -1,189 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = *a + *b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x0, x1, x2, x3, x4, x5;
|
||||
const double* aa = a;
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x0 = _mm256_loadu_pd(aa);
|
||||
x1 = _mm256_loadu_pd(aa + 4);
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_add_pd(x0, x2);
|
||||
x5 = _mm256_add_pd(x1, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
aa += 8;
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = *a - *b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x0, x1, x2, x3, x4, x5;
|
||||
const double* aa = a;
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x0 = _mm256_loadu_pd(aa);
|
||||
x1 = _mm256_loadu_pd(aa + 4);
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_sub_pd(x0, x2);
|
||||
x5 = _mm256_sub_pd(x1, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
aa += 8;
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
void rnx_negate_avx(uint64_t nn, double* res, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = -*b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x2, x3, x4, x5;
|
||||
const __m256d ZERO = _mm256_set1_pd(0);
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_sub_pd(ZERO, x2);
|
||||
x5 = _mm256_sub_pd(ZERO, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
|
||||
#include "vec_rnx_arithmetic.h"
|
||||
|
||||
typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F;
|
||||
typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F;
|
||||
typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F;
|
||||
typedef typeof(vec_rnx_add) VEC_RNX_ADD_F;
|
||||
typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F;
|
||||
typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F;
|
||||
typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F;
|
||||
typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F;
|
||||
typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F;
|
||||
typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F;
|
||||
typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F;
|
||||
typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F;
|
||||
typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F;
|
||||
typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F;
|
||||
typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F;
|
||||
// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F;
|
||||
typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F;
|
||||
typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F;
|
||||
typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F;
|
||||
typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F;
|
||||
typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F;
|
||||
typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F;
|
||||
typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F;
|
||||
typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F;
|
||||
typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F;
|
||||
typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F;
|
||||
typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||
typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F;
|
||||
typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F;
|
||||
|
||||
typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE;
|
||||
struct rnx_module_vtable_t {
|
||||
VEC_RNX_ZERO_F* vec_rnx_zero;
|
||||
VEC_RNX_COPY_F* vec_rnx_copy;
|
||||
VEC_RNX_NEGATE_F* vec_rnx_negate;
|
||||
VEC_RNX_ADD_F* vec_rnx_add;
|
||||
VEC_RNX_SUB_F* vec_rnx_sub;
|
||||
VEC_RNX_ROTATE_F* vec_rnx_rotate;
|
||||
VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one;
|
||||
VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism;
|
||||
VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32;
|
||||
VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32;
|
||||
VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32;
|
||||
VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32;
|
||||
VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2;
|
||||
VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2;
|
||||
VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl;
|
||||
RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product;
|
||||
RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes;
|
||||
TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product;
|
||||
TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes;
|
||||
ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product;
|
||||
ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes;
|
||||
TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product;
|
||||
TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes;
|
||||
RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32;
|
||||
RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2;
|
||||
RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl;
|
||||
BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol;
|
||||
RNX_SVP_PREPARE_F* rnx_svp_prepare;
|
||||
RNX_SVP_APPLY_F* rnx_svp_apply;
|
||||
BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat;
|
||||
RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous;
|
||||
RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr;
|
||||
RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row;
|
||||
RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes;
|
||||
RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a;
|
||||
RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes;
|
||||
RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft;
|
||||
RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes;
|
||||
VEC_RNX_DFT_F* vec_rnx_dft;
|
||||
VEC_RNX_IDFT_F* vec_rnx_idft;
|
||||
};
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
@@ -1,309 +0,0 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "vec_rnx_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_plugin.h"
|
||||
|
||||
typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP;
|
||||
struct fft64_rnx_module_precomp_t {
|
||||
REIM_FFT_PRECOMP* p_fft;
|
||||
REIM_IFFT_PRECOMP* p_ifft;
|
||||
REIM_FFTVEC_ADD_PRECOMP* p_fftvec_add;
|
||||
REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul;
|
||||
REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul;
|
||||
};
|
||||
|
||||
typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP;
|
||||
union rnx_module_precomp_t {
|
||||
FFT64_RNX_MODULE_PRECOMP fft64;
|
||||
};
|
||||
|
||||
void fft64_init_rnx_module_precomp(MOD_RNX* module);
|
||||
|
||||
void fft64_finalize_rnx_module_precomp(MOD_RNX* module);
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
struct rnx_module_info_t {
|
||||
uint64_t n;
|
||||
uint64_t m;
|
||||
RNX_MODULE_TYPE mtype;
|
||||
RNX_MODULE_VTABLE vtable;
|
||||
RNX_MODULE_PRECOMP precomp;
|
||||
void* custom;
|
||||
void (*custom_deleter)(void*);
|
||||
};
|
||||
|
||||
void init_rnx_module_info(MOD_RNX* module, //
|
||||
uint64_t, RNX_MODULE_TYPE mtype);
|
||||
|
||||
void finalize_rnx_module_info(MOD_RNX* module);
|
||||
|
||||
void fft64_init_rnx_module_vtable(MOD_RNX* module);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared gadget decompositions (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
struct tnx32_approxdec_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
int32_t add_cst; // 1/2.(sum 2^-(i+1)K)
|
||||
int32_t rshift_base; // 32 - K
|
||||
int64_t and_mask; // 2^K-1
|
||||
int64_t or_mask; // double(2^52)
|
||||
double sub_cst; // double(2^52 + 2^(K-1))
|
||||
uint8_t rshifts[8]; // 32 - (i+1).K
|
||||
};
|
||||
|
||||
struct tnx32x2_approxdec_gadget_t {
|
||||
// TODO
|
||||
};
|
||||
|
||||
struct tnxdbl_approxdecomp_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||
uint64_t and_mask; // uint64(2^(K)-1)
|
||||
uint64_t or_mask; // double(2^52)
|
||||
double sub_cst; // double(2^52 + 2^(K-1))
|
||||
};
|
||||
|
||||
EXPORT void vec_rnx_add_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_rnx_add_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module);
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module);
|
||||
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/// gadget decompositions
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
|
||||
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
@@ -1,91 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
EXPORT void vec_rnx_to_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
static void dbl_to_tndbl_ref( //
|
||||
const void* UNUSED, // N
|
||||
double* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double OFF_CST = INT64_C(3) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
double ai = a[i] + OFF_CST;
|
||||
res[i] = a[i] - (ai - OFF_CST);
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }
|
||||
|
||||
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }
|
||||
|
||||
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
) {
|
||||
double* const dppol = (double*)ppol;
|
||||
rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
|
||||
reim_fft(module->precomp.fft64.p_fft, dppol);
|
||||
}
|
||||
|
||||
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
double* const res_ptr = res + i * res_sl;
|
||||
// copy the polynomial to res, apply fft in place, call fftvec
|
||||
// _mul, apply ifft in place.
|
||||
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||
reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr);
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol);
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res_ptr);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,254 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_avx(nn, m, res, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (row_max > 0 && col_max > 0) {
|
||||
if (nn >= 8) {
|
||||
// let's do some prefetching of the GSW key, since on some cpus,
|
||||
// it helps
|
||||
const uint64_t ms4 = m >> 2; // m/4
|
||||
const uint64_t gsw_iter_doubles = 8 * nrows * ncols;
|
||||
const uint64_t pref_doubles = 1200;
|
||||
const double* gsw_pref_ptr = mat_input;
|
||||
const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles;
|
||||
const double* gsw_pref_ptr_target = mat_input + pref_doubles;
|
||||
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||
}
|
||||
const double* mat_blk_start;
|
||||
uint64_t blk_i;
|
||||
for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) {
|
||||
// prefetch the next iteration
|
||||
if (gsw_pref_ptr_target < gsw_ptr_end) {
|
||||
gsw_pref_ptr_target += gsw_iter_doubles;
|
||||
if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end;
|
||||
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const double* in;
|
||||
uint64_t in_sl;
|
||||
if (res == a_dft) {
|
||||
// it is in place: copy the input vector
|
||||
in = (double*)tmp_space;
|
||||
in_sl = nn;
|
||||
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
// it is out of place: do the product directly
|
||||
in = a_dft;
|
||||
in_sl = a_sl;
|
||||
}
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
{
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||
res + col_i * res_sl, //
|
||||
in, //
|
||||
pmat_col);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||
res + col_i * res_sl, //
|
||||
in + row_i * in_sl, //
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero out remaining bytes (if any)
|
||||
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||
|
||||
// fft is done in place on the input (tmpa is destroyed)
|
||||
for (uint64_t i = 0; i < rows; ++i) {
|
||||
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||
}
|
||||
fft64_rnx_vmp_apply_dft_to_dft_avx(module, //
|
||||
res, cols, res_sl, //
|
||||
tmpa, rows, a_sl, //
|
||||
pmat, nrows, ncols, //
|
||||
tmp_space);
|
||||
// ifft is done in place on the output
|
||||
for (uint64_t i = 0; i < cols; ++i) {
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||
}
|
||||
// zero out the remaining positions
|
||||
for (uint64_t i = cols; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -1,309 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return nrows * ncols * module->n * sizeof(double);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_ref(nn, m, res, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) {
|
||||
const uint64_t nn = module->n;
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (row_max > 0 && col_max > 0) {
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const double* in;
|
||||
uint64_t in_sl;
|
||||
if (res == a_dft) {
|
||||
// it is in place: copy the input vector
|
||||
in = (double*)tmp_space;
|
||||
in_sl = nn;
|
||||
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
// it is out of place: do the product directly
|
||||
in = a_dft;
|
||||
in_sl = a_sl;
|
||||
}
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
{
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||
res + col_i * res_sl, //
|
||||
in, //
|
||||
pmat_col);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||
res + col_i * res_sl, //
|
||||
in + row_i * in_sl, //
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero out remaining bytes (if any)
|
||||
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||
|
||||
// fft is done in place on the input (tmpa is destroyed)
|
||||
for (uint64_t i = 0; i < rows; ++i) {
|
||||
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||
}
|
||||
fft64_rnx_vmp_apply_dft_to_dft_ref(module, //
|
||||
res, cols, res_sl, //
|
||||
tmpa, rows, a_sl, //
|
||||
pmat, nrows, ncols, //
|
||||
tmp_space);
|
||||
// ifft is done in place on the output
|
||||
for (uint64_t i = 0; i < cols; ++i) {
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||
}
|
||||
// zero out the remaining positions
|
||||
for (uint64_t i = cols; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
|
||||
return (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
// avx aliases that need to be defined in the same .c file
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module)
|
||||
__attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
// wrappers
|
||||
@@ -1,369 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../q120/q120_arithmetic.h"
|
||||
#include "../q120/q120_ntt.h"
|
||||
#include "../reim/reim_fft_internal.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
// general function (virtual dispatch)
|
||||
|
||||
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_add(module, // N
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
b, b_size, b_sl // b
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_sub(module, // N
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
b, b_size, b_sl // b
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_rotate(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||
const int64_t p, // p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_mul_xp_minus_one(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_automorphism(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space of size >= N
|
||||
) {
|
||||
module->func.vec_znx_normalize_base2k(module, nn, // N
|
||||
log2_base2k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_normalize_base2k_tmp_bytes(module, nn // N
|
||||
);
|
||||
}
|
||||
|
||||
// specialized function (ref)
|
||||
|
||||
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then negate to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_rotate_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_rotate_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N
|
||||
const int64_t p, // p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_mul_xp_minus_one_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_mul_xp_minus_one_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_automorphism_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_automorphism_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space of size >= N
|
||||
) {
|
||||
|
||||
// use MSB limb of res for carry propagation
|
||||
int64_t* cout = (int64_t*)tmp_space;
|
||||
int64_t* cin = 0x0;
|
||||
|
||||
// propagate carry until first limb of res
|
||||
int64_t i = a_size - 1;
|
||||
for (; i >= res_size; --i) {
|
||||
znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin);
|
||||
cin = cout;
|
||||
}
|
||||
|
||||
// propagate carry and normalize
|
||||
for (; i >= 1; --i) {
|
||||
znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin);
|
||||
cin = cout;
|
||||
}
|
||||
|
||||
// normalize last limb
|
||||
znx_normalize(nn, log2_base2k, res, 0x0, a, cin);
|
||||
|
||||
// extend result with zeros
|
||||
for (uint64_t i = a_size; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
// alias have to be defined in this unit: do not move
|
||||
#ifdef __APPLE__
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) {
|
||||
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||
}
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) {
|
||||
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||
}
|
||||
#else
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
module->func.vec_znx_zero(module, res, res_size, res_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
for (uint64_t i = 0; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
@@ -1,370 +0,0 @@
|
||||
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
#define SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
|
||||
/**
|
||||
* We support the following module families:
|
||||
* - FFT64:
|
||||
* all the polynomials should fit at all times over 52 bits.
|
||||
* for FHE implementations, the recommended limb-sizes are
|
||||
* between K=10 and 20, which is good for low multiplicative depths.
|
||||
* - NTT120:
|
||||
* all the polynomials should fit at all times over 119 bits.
|
||||
* for FHE implementations, the recommended limb-sizes are
|
||||
* between K=20 and 40, which is good for large multiplicative depths.
|
||||
*/
|
||||
typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describr the modules (ZnX,TnX) and the hardware */
|
||||
typedef struct module_info_t MODULE;
|
||||
/** @brief opaque type that represents a prepared matrix */
|
||||
typedef struct vmp_pmat_t VMP_PMAT;
|
||||
/** @brief opaque type that represents a vector of znx in DFT space */
|
||||
typedef struct vec_znx_dft_t VEC_ZNX_DFT;
|
||||
/** @brief opaque type that represents a vector of znx in large coeffs space */
|
||||
typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG;
|
||||
/** @brief opaque type that represents a prepared scalar vector product */
|
||||
typedef struct svp_ppol_t SVP_PPOL;
|
||||
/** @brief opaque type that represents a prepared left convolution vector product */
|
||||
typedef struct cnv_pvec_l_t CNV_PVEC_L;
|
||||
/** @brief opaque type that represents a prepared right convolution vector product */
|
||||
typedef struct cnv_pvec_r_t CNV_PVEC_R;
|
||||
|
||||
/** @brief bytes needed for a vec_znx in DFT space */
|
||||
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief allocates a vec_znx in DFT space */
|
||||
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief frees memory from a vec_znx in DFT space */
|
||||
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res);
|
||||
|
||||
/** @brief bytes needed for a vec_znx_big */
|
||||
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief allocates a vec_znx_big */
|
||||
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
/** @brief frees memory from a vec_znx_big */
|
||||
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res);
|
||||
|
||||
/** @brief bytes needed for a prepared vector */
|
||||
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N
|
||||
|
||||
/** @brief allocates a prepared vector */
|
||||
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N
|
||||
|
||||
/** @brief frees memory for a prepared vector */
|
||||
EXPORT void delete_svp_ppol(SVP_PPOL* res);
|
||||
|
||||
/** @brief bytes needed for a prepared matrix */
|
||||
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief allocates a prepared matrix */
|
||||
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief frees memory for a prepared matrix */
|
||||
EXPORT void delete_vmp_pmat(VMP_PMAT* res);
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode);
|
||||
EXPORT void delete_module_info(MODULE* module_info);
|
||||
EXPORT uint64_t module_get_n(const MODULE* module);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize-reduce(a) */
|
||||
EXPORT void vec_znx_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space (size >= N)
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a * (X^{p} - 1) */
|
||||
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||
const int64_t p, // X-X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_dft_zero(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
/** @brief sets res = DFT(a) */
|
||||
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */
|
||||
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
/** @brief tmp bytes required for vec_znx_idft */
|
||||
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/**
|
||||
* @brief sets res = iDFT(a_dft) -- output in big coeffs space
|
||||
*
|
||||
* @note a_dft is overwritten
|
||||
*/
|
||||
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn,
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols // a
|
||||
);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void svp_prepare(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
);
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp);
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief applies vmp product */
|
||||
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies vmp product and adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
const uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
@@ -1,563 +0,0 @@
|
||||
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "../q120/q120_ntt.h"
|
||||
#include "vec_znx_arithmetic.h"
|
||||
|
||||
/**
|
||||
* Layouts families:
|
||||
*
|
||||
* fft64:
|
||||
* K: <= 20, N: <= 65536, ell: <= 200
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int64 (expect <=52 bits)
|
||||
* vec<ZnX> DFT: represented by double (reim_fft space)
|
||||
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special reim4_fft space
|
||||
*
|
||||
* ntt120:
|
||||
* K: <= 50, N: <= 65536, ell: <= 80
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special ntt120 space
|
||||
*
|
||||
* ntt104:
|
||||
* K: <= 40, N: <= 65536, ell: <= 80
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||
* On AVX512 inftastructure, PMAT, LCNV, RCNV use a special ntt104 space
|
||||
*/
|
||||
|
||||
struct fft64_module_info_t {
|
||||
// pre-computation for reim_fft
|
||||
REIM_FFT_PRECOMP* p_fft;
|
||||
// pre-computation for add_fft
|
||||
REIM_FFTVEC_ADD_PRECOMP* add_fft;
|
||||
// pre-computation for add_fft
|
||||
REIM_FFTVEC_SUB_PRECOMP* sub_fft;
|
||||
// pre-computation for mul_fft
|
||||
REIM_FFTVEC_MUL_PRECOMP* mul_fft;
|
||||
// pre-computation for reim_from_znx6
|
||||
REIM_FROM_ZNX64_PRECOMP* p_conv;
|
||||
// pre-computation for reim_tp_znx6
|
||||
REIM_TO_ZNX64_PRECOMP* p_reim_to_znx;
|
||||
// pre-computation for reim_fft
|
||||
REIM_IFFT_PRECOMP* p_ifft;
|
||||
// pre-computation for reim_fftvec_addmul
|
||||
REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul;
|
||||
};
|
||||
|
||||
struct q120_module_info_t {
|
||||
// pre-computation for q120b to q120b ntt
|
||||
q120_ntt_precomp* p_ntt;
|
||||
// pre-computation for q120b to q120b intt
|
||||
q120_ntt_precomp* p_intt;
|
||||
};
|
||||
|
||||
// TODO add function types here
|
||||
typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F;
|
||||
typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F;
|
||||
typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F;
|
||||
typedef typeof(vec_znx_add) VEC_ZNX_ADD_F;
|
||||
typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F;
|
||||
typedef typeof(vec_dft_add) VEC_DFT_ADD_F;
|
||||
typedef typeof(vec_dft_sub) VEC_DFT_SUB_F;
|
||||
typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F;
|
||||
typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F;
|
||||
typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F;
|
||||
typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F;
|
||||
typedef typeof(vec_znx_mul_xp_minus_one) VEC_ZNX_MUL_XP_MINUS_ONE_F;
|
||||
typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F;
|
||||
typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F;
|
||||
typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F;
|
||||
typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F;
|
||||
typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F;
|
||||
typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F;
|
||||
typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F;
|
||||
typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F;
|
||||
typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F;
|
||||
typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F;
|
||||
typedef typeof(svp_prepare) SVP_PREPARE;
|
||||
typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F;
|
||||
typedef typeof(svp_apply_dft_to_dft) SVP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(vmp_prepare_tmp_bytes) VMP_PREPARE_TMP_BYTES_F;
|
||||
typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F;
|
||||
typedef typeof(vmp_apply_dft_add) VMP_APPLY_DFT_ADD_F;
|
||||
typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft_add) VMP_APPLY_DFT_TO_DFT_ADD_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||
typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F;
|
||||
typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F;
|
||||
typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F;
|
||||
typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F;
|
||||
|
||||
struct module_virtual_functions_t {
|
||||
// TODO add functions here
|
||||
VEC_ZNX_ZERO_F* vec_znx_zero;
|
||||
VEC_ZNX_COPY_F* vec_znx_copy;
|
||||
VEC_ZNX_NEGATE_F* vec_znx_negate;
|
||||
VEC_ZNX_ADD_F* vec_znx_add;
|
||||
VEC_ZNX_DFT_F* vec_znx_dft;
|
||||
VEC_DFT_ADD_F* vec_dft_add;
|
||||
VEC_DFT_SUB_F* vec_dft_sub;
|
||||
VEC_ZNX_IDFT_F* vec_znx_idft;
|
||||
VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes;
|
||||
VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a;
|
||||
VEC_ZNX_SUB_F* vec_znx_sub;
|
||||
VEC_ZNX_ROTATE_F* vec_znx_rotate;
|
||||
VEC_ZNX_MUL_XP_MINUS_ONE_F* vec_znx_mul_xp_minus_one;
|
||||
VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism;
|
||||
VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k;
|
||||
VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k;
|
||||
VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k;
|
||||
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_ADD_F* vec_znx_big_add;
|
||||
VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small;
|
||||
VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2;
|
||||
VEC_ZNX_BIG_SUB_F* vec_znx_big_sub;
|
||||
VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a;
|
||||
VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b;
|
||||
VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2;
|
||||
VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate;
|
||||
VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism;
|
||||
SVP_PREPARE* svp_prepare;
|
||||
SVP_APPLY_DFT_F* svp_apply_dft;
|
||||
SVP_APPLY_DFT_TO_DFT_F* svp_apply_dft_to_dft;
|
||||
ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product;
|
||||
ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes;
|
||||
VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous;
|
||||
VMP_PREPARE_TMP_BYTES_F* vmp_prepare_tmp_bytes;
|
||||
VMP_APPLY_DFT_F* vmp_apply_dft;
|
||||
VMP_APPLY_DFT_ADD_F* vmp_apply_dft_add;
|
||||
VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes;
|
||||
VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft;
|
||||
VMP_APPLY_DFT_TO_DFT_ADD_F* vmp_apply_dft_to_dft_add;
|
||||
VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes;
|
||||
BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft;
|
||||
BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big;
|
||||
BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol;
|
||||
BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat;
|
||||
};
|
||||
|
||||
union backend_module_info_t {
|
||||
struct fft64_module_info_t fft64;
|
||||
struct q120_module_info_t q120;
|
||||
};
|
||||
|
||||
struct module_info_t {
|
||||
// generic parameters
|
||||
MODULE_TYPE module_type;
|
||||
uint64_t nn;
|
||||
uint64_t m;
|
||||
// backend_dependent functions
|
||||
union backend_module_info_t mod;
|
||||
// virtual functions
|
||||
struct module_virtual_functions_t func;
|
||||
};
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module); // N
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // inp
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vmp_prepare_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
EXPORT void vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_zero_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_sub_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_idft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size);
|
||||
|
||||
EXPORT void vec_znx_big_normalize_ref(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size,
|
||||
uint64_t a_cols // a
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn,
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, // a
|
||||
uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** */
|
||||
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn);
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
// big additions/subtractions
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
);
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp);
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace*/
|
||||
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies rmp product and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies rmp product and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
@@ -1,103 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
// specialized function (ref)
|
||||
|
||||
// Note: these functions do not have an avx variant.
|
||||
#define znx_copy_i64_avx znx_copy_i64_ref
|
||||
#define znx_zero_i64_avx znx_zero_i64_ref
|
||||
|
||||
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then negate to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->func.bytes_of_vec_znx_big(module, size);
|
||||
}
|
||||
|
||||
// public wrappers
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size);
|
||||
}
|
||||
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
// private wrappers
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->nn * size * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return spqlios_alloc(bytes_of_vec_znx_big(module, size));
|
||||
}
|
||||
|
||||
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); }
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
b, b_size, b_sl);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
a, a_size, a_sl, //
|
||||
b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, //
|
||||
n, b, b_size, b_sl);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
a, a_size, a_sl, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, //
|
||||
n, a, a_size, //
|
||||
a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
) {
|
||||
module->func.vec_znx_big_normalize_base2k(module, // MODULE
|
||||
nn, // N
|
||||
k, // base-2^k
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module, nn // N
|
||||
);
|
||||
}
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
) {
|
||||
module->func.vec_znx_big_range_normalize_base2k(module, nn, log2_base2k, res, res_size, res_sl, a, a_range_begin,
|
||||
a_range_xend, a_range_step, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module, nn);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space) {
|
||||
uint64_t a_sl = nn;
|
||||
module->func.vec_znx_normalize_base2k(module, // N
|
||||
nn,
|
||||
k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
(int64_t*)a, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a
|
||||
uint8_t* tmp_space) {
|
||||
// convert the range indexes to int64[] slices
|
||||
const int64_t* a_st = ((int64_t*)a) + nn * a_begin;
|
||||
const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step;
|
||||
const uint64_t a_sl = nn * a_step;
|
||||
// forward the call
|
||||
module->func.vec_znx_normalize_base2k(module, // MODULE
|
||||
nn, // N
|
||||
k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
a_st, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../q120/q120_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
return module->func.vec_dft_add(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
return module->func.vec_dft_sub(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
) {
|
||||
return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return module->func.vec_znx_idft_tmp_bytes(module, nn); }
|
||||
|
||||
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size);
|
||||
}
|
||||
|
||||
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->func.bytes_of_vec_znx_dft(module, size);
|
||||
}
|
||||
|
||||
// fft64 backend
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->nn * size * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return spqlios_alloc(bytes_of_vec_znx_dft(module, size));
|
||||
}
|
||||
|
||||
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); }
|
||||
|
||||
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl);
|
||||
reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t smin0 = a_size < b_size ? a_size : b_size;
|
||||
const uint64_t smin = res_size < smin0 ? res_size : smin0;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_fftvec_add(module->mod.fft64.add_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
|
||||
}
|
||||
|
||||
// fill remain `res` part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t smin0 = a_size < b_size ? a_size : b_size;
|
||||
const uint64_t smin = res_size < smin0 ? res_size : smin0;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_fftvec_sub(module->mod.fft64.sub_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
|
||||
}
|
||||
|
||||
// fill remain `res` part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // unused
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
if ((double*)res != (double*)a_dft) {
|
||||
memcpy(res, a_dft, smin * nn * sizeof(double));
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
int64_t* const dres = (int64_t*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return 0; }
|
||||
|
||||
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
int64_t* const tres = (int64_t*)res;
|
||||
double* const ta = (double*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_ifft(module->mod.fft64.p_ifft, ta + i * nn);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
// ntt120 backend
|
||||
|
||||
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
int64_t* tres = (int64_t*)res;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl);
|
||||
q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4));
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t));
|
||||
}
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
__int128_t* const tres = (__int128_t*)res;
|
||||
const int64_t* const ta = (int64_t*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t));
|
||||
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp);
|
||||
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||
}
|
||||
|
||||
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn) { return nn * 4 * sizeof(uint64_t); }
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
__int128_t* const tres = (__int128_t*)res;
|
||||
int64_t* const ta = (int64_t*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4));
|
||||
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4));
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
@@ -1,369 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return module->func.bytes_of_vmp_pmat(module, nrows, ncols);
|
||||
}
|
||||
|
||||
// fft64
|
||||
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return module->nn * nrows * ncols * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
|
||||
EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); }
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return module->func.vmp_prepare_tmp_bytes(module, nn, nrows, ncols);
|
||||
}
|
||||
|
||||
EXPORT double* get_blk_addr(uint64_t row_i, uint64_t col_i, uint64_t nrows, uint64_t ncols, const VMP_PMAT* pmat) {
|
||||
double* output_mat = (double*)pmat;
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
return output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
return output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
}
|
||||
|
||||
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
|
||||
uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
|
||||
double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||
fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(nn, m, (SVP_PPOL*)tmp_space, row_i, col_i, nrows, ncols, pmat);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_add_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief like fft64_vmp_apply_dft_to_dft_ref but adds in place */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
assert(nn >= 8);
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
// const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
|
||||
if (pmat_scale % 2 == 0) {
|
||||
// apply mat2cols
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
} else {
|
||||
uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output, mat2cols_output + 8);
|
||||
|
||||
// apply mat2cols
|
||||
for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
if (last_col >= pmat_scale) {
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
|
||||
double* pmat_col = mat_input + col_pmat * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
assert(nn >= 8);
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
// const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
return (row_max * nn * sizeof(double)) + (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
|
||||
return (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->func.vmp_apply_dft_to_dft_add(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_apply_dft_add(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, pmat_scale, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->func.vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
@@ -1,244 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double* get_blk_addr(int row, int col, int nrows, int ncols, VMP_PMAT* pmat);
|
||||
|
||||
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_avx(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
|
||||
uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
|
||||
double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) abd adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_add_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
|
||||
if (pmat_scale % 2 == 0) {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
} else {
|
||||
uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output, mat2cols_output + 8);
|
||||
|
||||
for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
if (last_col >= pmat_scale) {
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max)
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
|
||||
double* pmat_col = mat_input + col_pmat * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max)
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
void default_init_z_module_precomp(MOD_Z* module) {
|
||||
// Add here initialization of items that are in the precomp
|
||||
}
|
||||
|
||||
void default_finalize_z_module_precomp(MOD_Z* module) {
|
||||
// Add here deleters for items that are in the precomp
|
||||
}
|
||||
|
||||
void default_init_z_module_vtable(MOD_Z* module) {
|
||||
// Add function pointers here
|
||||
module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref;
|
||||
module->vtable.zn32_vmp_prepare_dblptr = default_zn32_vmp_prepare_dblptr_ref;
|
||||
module->vtable.zn32_vmp_prepare_row = default_zn32_vmp_prepare_row_ref;
|
||||
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref;
|
||||
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref;
|
||||
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref;
|
||||
module->vtable.dbl_to_tn32 = dbl_to_tn32_ref;
|
||||
module->vtable.tn32_to_dbl = tn32_to_dbl_ref;
|
||||
module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref;
|
||||
module->vtable.i32_to_dbl = i32_to_dbl_ref;
|
||||
module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref;
|
||||
module->vtable.i64_to_dbl = i64_to_dbl_ref;
|
||||
|
||||
// Add optimized function pointers here
|
||||
if (CPU_SUPPORTS("avx")) {
|
||||
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx;
|
||||
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx;
|
||||
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx;
|
||||
}
|
||||
}
|
||||
|
||||
void init_z_module_info(MOD_Z* module, //
|
||||
Z_MODULE_TYPE mtype) {
|
||||
memset(module, 0, sizeof(MOD_Z));
|
||||
module->mtype = mtype;
|
||||
switch (mtype) {
|
||||
case DEFAULT:
|
||||
default_init_z_module_precomp(module);
|
||||
default_init_z_module_vtable(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
void finalize_z_module_info(MOD_Z* module) {
|
||||
if (module->custom) module->custom_deleter(module->custom);
|
||||
switch (module->mtype) {
|
||||
case DEFAULT:
|
||||
default_finalize_z_module_precomp(module);
|
||||
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) {
|
||||
MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z));
|
||||
init_z_module_info(res, mtype);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_z_module_info(MOD_Z* module_info) {
|
||||
finalize_z_module_info(module_info);
|
||||
free(module_info);
|
||||
}
|
||||
|
||||
//////////////// wrappers //////////////////
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void zn32_vmp_prepare_contiguous(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_dblptr(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_dblptr(module, pmat, mat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void zn32_vmp_prepare_row(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_row(module, pmat, row, row_i, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_to_tn32(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.tn32_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.i32_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** small int (int64 space, <= 2^50) to double */
|
||||
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.i64_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
uint64_t k, uint64_t ell) {
|
||||
if (k * ell > 50) {
|
||||
return spqlios_error("approx decomposition requested is too precise for doubles");
|
||||
}
|
||||
if (k < 1) {
|
||||
return spqlios_error("approx decomposition supports k>=1");
|
||||
}
|
||||
TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||
memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||
res->k = k;
|
||||
res->ell = ell;
|
||||
double add_cst = INT64_C(3) << (51 - k * ell);
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
add_cst += pow(2., -(double)(i * k + 1));
|
||||
}
|
||||
res->add_cst = add_cst;
|
||||
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||
res->sub_cst = UINT64_C(1) << (k - 1);
|
||||
for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k;
|
||||
return res;
|
||||
}
|
||||
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); }
|
||||
|
||||
EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
TNDBL_APPROXDECOMP_GADGET* res, //
|
||||
uint64_t k, uint64_t ell) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
typedef union {
|
||||
double dv;
|
||||
uint64_t uv;
|
||||
} du_t;
|
||||
|
||||
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \
|
||||
if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \
|
||||
const uint64_t ell = gadget->ell; \
|
||||
const double add_cst = gadget->add_cst; \
|
||||
const uint8_t* const rshifts = gadget->rshifts; \
|
||||
const ITYPE and_mask = gadget->and_mask; \
|
||||
const ITYPE sub_cst = gadget->sub_cst; \
|
||||
ITYPE* rr = res; \
|
||||
const double* aa = a; \
|
||||
const double* aaend = a + a_size; \
|
||||
while (aa < aaend) { \
|
||||
du_t t = {.dv = *aa + add_cst}; \
|
||||
for (uint64_t i = 0; i < ell; ++i) { \
|
||||
ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \
|
||||
*rr = (v & and_mask) - sub_cst; \
|
||||
++rr; \
|
||||
} \
|
||||
++aa; \
|
||||
}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size //
|
||||
){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t)
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describes the module and the hardware */
|
||||
typedef struct z_module_info_t MOD_Z;
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode);
|
||||
EXPORT void delete_z_module_info(MOD_Z* module_info);
|
||||
|
||||
typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET;
|
||||
|
||||
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
uint64_t k,
|
||||
uint64_t ell); // base 2^k, and size
|
||||
|
||||
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief opaque type that represents a prepared matrix */
|
||||
typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT;
|
||||
|
||||
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief deletes a prepared matrix (release with free) */
|
||||
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void zn32_vmp_prepare_contiguous( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_dblptr( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_row( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i32( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i16( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i8( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
// explicit conversions
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in i32 space.
|
||||
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||
*/
|
||||
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int32 space) to double
|
||||
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||
*/
|
||||
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in int64 space
|
||||
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||
*/
|
||||
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int64 space, <= 2^50) to double
|
||||
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||
*/
|
||||
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_H
|
||||
@@ -1,43 +0,0 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
|
||||
#include "zn_arithmetic.h"
|
||||
|
||||
typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F;
|
||||
typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(zn32_vmp_prepare_dblptr) ZN32_VMP_PREPARE_DBLPTR_F;
|
||||
typedef typeof(zn32_vmp_prepare_row) ZN32_VMP_PREPARE_ROW_F;
|
||||
typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F;
|
||||
typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F;
|
||||
typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F;
|
||||
typedef typeof(dbl_to_tn32) DBL_TO_TN32_F;
|
||||
typedef typeof(tn32_to_dbl) TN32_TO_DBL_F;
|
||||
typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F;
|
||||
typedef typeof(i32_to_dbl) I32_TO_DBL_F;
|
||||
typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F;
|
||||
typedef typeof(i64_to_dbl) I64_TO_DBL_F;
|
||||
|
||||
typedef struct z_module_vtable_t Z_MODULE_VTABLE;
|
||||
struct z_module_vtable_t {
|
||||
I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl;
|
||||
I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl;
|
||||
I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl;
|
||||
BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat;
|
||||
ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous;
|
||||
ZN32_VMP_PREPARE_DBLPTR_F* zn32_vmp_prepare_dblptr;
|
||||
ZN32_VMP_PREPARE_ROW_F* zn32_vmp_prepare_row;
|
||||
ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32;
|
||||
ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16;
|
||||
ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8;
|
||||
DBL_TO_TN32_F* dbl_to_tn32;
|
||||
TN32_TO_DBL_F* tn32_to_dbl;
|
||||
DBL_ROUND_TO_I32_F* dbl_round_to_i32;
|
||||
I32_TO_DBL_F* i32_to_dbl;
|
||||
DBL_ROUND_TO_I64_F* dbl_round_to_i64;
|
||||
I64_TO_DBL_F* i64_to_dbl;
|
||||
};
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
@@ -1,164 +0,0 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "zn_arithmetic.h"
|
||||
#include "zn_arithmetic_plugin.h"
|
||||
|
||||
typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP;
|
||||
struct main_z_module_precomp_t {
|
||||
// TODO
|
||||
};
|
||||
|
||||
typedef union z_module_precomp_t Z_MODULE_PRECOMP;
|
||||
union z_module_precomp_t {
|
||||
MAIN_Z_MODULE_PRECOMP main;
|
||||
};
|
||||
|
||||
void main_init_z_module_precomp(MOD_Z* module);
|
||||
|
||||
void main_finalize_z_module_precomp(MOD_Z* module);
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
struct z_module_info_t {
|
||||
Z_MODULE_TYPE mtype;
|
||||
Z_MODULE_VTABLE vtable;
|
||||
Z_MODULE_PRECOMP precomp;
|
||||
void* custom;
|
||||
void (*custom_deleter)(void*);
|
||||
};
|
||||
|
||||
void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype);
|
||||
|
||||
void main_init_z_module_vtable(MOD_Z* module);
|
||||
|
||||
struct tndbl_approxdecomp_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K)
|
||||
int64_t and_mask; // (2^K)-1
|
||||
int64_t sub_cst; // 2^(K-1)
|
||||
uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1]
|
||||
};
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res,
|
||||
uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res,
|
||||
uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_ref( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i16_ref( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i8_ref( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_avx( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i16_avx( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i8_avx( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
// explicit conversions
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int64 space) to double */
|
||||
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
@@ -1,108 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
typedef union {
|
||||
double dv;
|
||||
int64_t s64v;
|
||||
int32_t s32v;
|
||||
uint64_t u64v;
|
||||
uint32_t u32v;
|
||||
} di_t;
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32));
|
||||
static const int32_t XOR_CST = (INT32_C(1) << 31);
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = t.s32v ^ XOR_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))};
|
||||
static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32));
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
uint32_t ai = a[i] ^ XOR_CST;
|
||||
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31)));
|
||||
static const int32_t XOR_CST = INT32_C(1) << 31;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = t.s32v ^ XOR_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)};
|
||||
static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31));
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
uint32_t ai = a[i] ^ XOR_CST;
|
||||
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = (double)(INT64_C(3) << (51));
|
||||
static const int64_t AND_CST = (INT64_C(1) << 52) - 1;
|
||||
static const int64_t SUB_CST = INT64_C(1) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = (t.s64v & AND_CST) - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int64_t));
|
||||
}
|
||||
|
||||
/** small int (int64 space) to double */
|
||||
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint64_t ADD_CST = UINT64_C(1) << 51;
|
||||
static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1;
|
||||
static const di_t OR_CST = {.dv = (INT64_C(1) << 52)};
|
||||
static const double SUB_CST = INT64_C(3) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
#define INTTYPE int16_t
|
||||
#define INTSN i16
|
||||
|
||||
#include "zn_vmp_int32_avx.c"
|
||||
@@ -1,4 +0,0 @@
|
||||
#define INTTYPE int16_t
|
||||
#define INTSN i16
|
||||
|
||||
#include "zn_vmp_int32_ref.c"
|
||||
@@ -1,223 +0,0 @@
|
||||
// This file is actually a template: it will be compiled multiple times with
|
||||
// different INTTYPES
|
||||
#ifndef INTTYPE
|
||||
#define INTTYPE int32_t
|
||||
#define INTSN i32
|
||||
#endif
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||
|
||||
static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 32 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const int32_t* bb = b;
|
||||
const int32_t* pref_bb = b;
|
||||
const uint64_t pref_iters = 128;
|
||||
const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows;
|
||||
const uint64_t pref_last = pref_iters > nrows ? 0 : nrows - pref_iters;
|
||||
// let's do some prefetching of the GSW key, since on some cpus,
|
||||
// it helps
|
||||
for (uint64_t i = 0; i < pref_start; ++i) {
|
||||
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||
pref_bb += 32;
|
||||
}
|
||||
// we do the first iteration
|
||||
__m256i x = _mm256_set1_epi32(a[0]);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||
bb += 32;
|
||||
uint64_t row = 1;
|
||||
for (; //
|
||||
row < pref_last; //
|
||||
++row, bb += 32) {
|
||||
// prefetch the next iteration
|
||||
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||
pref_bb += 32;
|
||||
INTTYPE ai = a[row];
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
for (; //
|
||||
row < nrows; //
|
||||
++row, bb += 32) {
|
||||
INTTYPE ai = a[row];
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 32 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 24 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
}
|
||||
void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 16 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 8 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
}
|
||||
|
||||
typedef void (*vm_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const INTTYPE* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
static const vm_f zn32_vec_mat8kcols_avx[4] = { //
|
||||
zn32_vec_fn(mat8cols_avx), //
|
||||
zn32_vec_fn(mat16cols_avx), //
|
||||
zn32_vec_fn(mat24cols_avx), //
|
||||
zn32_vec_fn(mat32cols_avx)};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const INTTYPE* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint64_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
// This file is actually a template: it will be compiled multiple times with
|
||||
// different INTTYPES
|
||||
#ifndef INTTYPE
|
||||
#define INTTYPE int32_t
|
||||
#define INTSN i32
|
||||
#endif
|
||||
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||
|
||||
// the ref version shares the same implementation for each fixed column size
|
||||
// optimized implementations may do something different.
|
||||
static __always_inline void IMPL_zn32_vec_matcols_ref(
|
||||
const uint64_t NCOLS, // fixed number of columns
|
||||
uint64_t nrows, // nrows of b
|
||||
int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant
|
||||
const INTTYPE* a, // a: nrows-sized vector
|
||||
const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix
|
||||
) {
|
||||
memset(res, 0, NCOLS * sizeof(int32_t));
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t ai = a[row];
|
||||
const int32_t* bb = b + row * b_sl;
|
||||
for (uint64_t i = 0; i < NCOLS; ++i) {
|
||||
res[i] += ai * bb[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl);
|
||||
}
|
||||
|
||||
typedef void (*vm_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const INTTYPE* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
static const vm_f zn32_vec_mat8kcols_ref[4] = { //
|
||||
zn32_vec_fn(mat8cols_ref), //
|
||||
zn32_vec_fn(mat16cols_ref), //
|
||||
zn32_vec_fn(mat24cols_ref), //
|
||||
zn32_vec_fn(mat32cols_ref)};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const INTTYPE* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint32_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
#define INTTYPE int8_t
|
||||
#define INTSN i8
|
||||
|
||||
#include "zn_vmp_int32_avx.c"
|
||||
@@ -1,4 +0,0 @@
|
||||
#define INTTYPE int8_t
|
||||
#define INTSN i8
|
||||
|
||||
#include "zn_vmp_int32_ref.c"
|
||||
@@ -1,185 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return (nrows * ncols + 7) * sizeof(int32_t);
|
||||
}
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
|
||||
/** @brief deletes a prepared matrix (release with free) */
|
||||
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
int32_t* const out = (int32_t*)pmat;
|
||||
const uint64_t nblk = ncols >> 5;
|
||||
const uint64_t ncols_rem = ncols & 31;
|
||||
const uint64_t final_elems = (8 - nrows * ncols) & 7;
|
||||
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||
int32_t* outblk = out + blk * nrows * 32;
|
||||
const int32_t* srcblk = mat + blk * 32;
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t* dest = outblk + row * 32;
|
||||
const int32_t* src = srcblk + row * ncols;
|
||||
for (uint64_t i = 0; i < 32; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
// copy the last block if any
|
||||
if (ncols_rem) {
|
||||
int32_t* outblk = out + nblk * nrows * 32;
|
||||
const int32_t* srcblk = mat + nblk * 32;
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t* dest = outblk + row * ncols_rem;
|
||||
const int32_t* src = srcblk + row * ncols;
|
||||
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero-out the final elements that may be accessed
|
||||
if (final_elems) {
|
||||
int32_t* f = out + nrows * ncols;
|
||||
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||
f[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; ++row_i) {
|
||||
default_zn32_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
int32_t* const out = (int32_t*)pmat;
|
||||
const uint64_t nblk = ncols >> 5;
|
||||
const uint64_t ncols_rem = ncols & 31;
|
||||
const uint64_t final_elems = (row_i == nrows - 1) && (8 - nrows * ncols) & 7;
|
||||
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||
int32_t* outblk = out + blk * nrows * 32;
|
||||
int32_t* dest = outblk + row_i * 32;
|
||||
const int32_t* src = row + blk * 32;
|
||||
for (uint64_t i = 0; i < 32; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
// copy the last block if any
|
||||
if (ncols_rem) {
|
||||
int32_t* outblk = out + nblk * nrows * 32;
|
||||
int32_t* dest = outblk + row_i * ncols_rem;
|
||||
const int32_t* src = row + nblk * 32;
|
||||
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
// zero-out the final elements that may be accessed
|
||||
if (final_elems) {
|
||||
int32_t* f = out + nrows * ncols;
|
||||
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||
f[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
|
||||
#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \
|
||||
memset(res, 0, NCOLS * sizeof(int32_t)); \
|
||||
for (uint64_t row = 0; row < nrows; ++row) { \
|
||||
int32_t ai = a[row]; \
|
||||
const int32_t* bb = b + row * b_sl; \
|
||||
for (uint64_t i = 0; i < NCOLS; ++i) { \
|
||||
res[i] += ai * bb[i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8)
|
||||
#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16)
|
||||
#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24)
|
||||
#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32)
|
||||
|
||||
void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
|
||||
void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat24cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat16cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat8cols_ref()
|
||||
}
|
||||
typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const int32_t* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { //
|
||||
zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, //
|
||||
zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const int32_t* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint32_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,38 +0,0 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp) {
|
||||
const uint64_t nn = module->nn;
|
||||
double* const ffta = (double*)tmp;
|
||||
double* const fftb = ((double*)tmp) + nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ffta, a);
|
||||
reim_from_znx64(module->mod.fft64.p_conv, fftb, b);
|
||||
reim_fft(module->mod.fft64.p_fft, ffta);
|
||||
reim_fft(module->mod.fft64.p_fft, fftb);
|
||||
reim_fftvec_mul_simple(module->m, ffta, ffta, fftb);
|
||||
reim_ifft(module->mod.fft64.p_ifft, ffta);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta);
|
||||
}
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||
return 2 * nn * sizeof(double);
|
||||
}
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp) {
|
||||
module->func.znx_small_single_product(module, res, a, b, tmp);
|
||||
}
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||
return module->func.znx_small_single_product_tmp_bytes(module, nn);
|
||||
}
|
||||
@@ -1,524 +0,0 @@
|
||||
#include "coeffs_arithmetic.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <memory.h>
|
||||
|
||||
/** res = a + b */
|
||||
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
/** res = a - b */
|
||||
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] - b[i];
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = -a[i];
|
||||
}
|
||||
}
|
||||
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); }
|
||||
|
||||
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); }
|
||||
|
||||
EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) {
|
||||
const double invm = 1. / m;
|
||||
for (uint64_t i = 0; i < n; ++i) {
|
||||
res[i] = a[i] * invm;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma] - in[j];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma] - in[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma] - in[j];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma] - in[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp2 = res[new_j_n];
|
||||
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
// 0 < p < 2nn
|
||||
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
res[0] = in[0];
|
||||
uint64_t a = 0;
|
||||
uint64_t _2mn = 2 * nn - 1;
|
||||
for (uint64_t i = 1; i < nn; i++) {
|
||||
a = (a + p) & _2mn; // i*p mod 2n
|
||||
if (a < nn) {
|
||||
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||
} else {
|
||||
res[a - nn] = -in[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
res[0] = in[0];
|
||||
uint64_t a = 0;
|
||||
uint64_t _2mn = 2 * nn - 1;
|
||||
for (uint64_t i = 1; i < nn; i++) {
|
||||
a = (a + p) & _2mn;
|
||||
if (a < nn) {
|
||||
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||
} else {
|
||||
res[a - nn] = -in[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp2 = res[new_j_n];
|
||||
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp2 = res[new_j_n];
|
||||
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp2 = res[new_j_n];
|
||||
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) {
|
||||
return (x << (64 - base_k)) >> (64 - base_k);
|
||||
}
|
||||
|
||||
__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) {
|
||||
return (x - digit) >> base_k;
|
||||
}
|
||||
|
||||
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||
const int64_t* carry_in) {
|
||||
assert(in);
|
||||
if (out != 0) {
|
||||
if (carry_in != 0x0 && carry_out != 0x0) {
|
||||
// with carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||
|
||||
out[i] = y;
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
} else if (carry_in != 0) {
|
||||
// with carry in and carry out is dropped
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
|
||||
out[i] = y;
|
||||
}
|
||||
|
||||
} else if (carry_out != 0) {
|
||||
// no carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
|
||||
int64_t y = get_base_k_digit(x, base_k);
|
||||
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||
|
||||
out[i] = y;
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
|
||||
} else {
|
||||
// no carry in and carry out is dropped
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
out[i] = get_base_k_digit(in[i], base_k);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert(carry_out);
|
||||
if (carry_in != 0x0) {
|
||||
// with carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
} else {
|
||||
// no carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
|
||||
int64_t y = get_base_k_digit(x, base_k);
|
||||
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
const uint64_t m = nn >> 1;
|
||||
// reduce p mod 2n
|
||||
p &= _2mn;
|
||||
// uint64_t vp = p & _2mn;
|
||||
/// uint64_t target_modifs = m >> 1;
|
||||
// we proceed by increasing binary valuation
|
||||
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||
// At the beginning of this loop we have:
|
||||
// vp = binval * p mod 2n
|
||||
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||
|
||||
// first, handle the orders 1 and 2.
|
||||
// if p*binval == binval % 2n: we're done!
|
||||
if (vp == binval) return;
|
||||
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||
if (((vp + binval) & _2mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += binval) {
|
||||
int64_t tmp = res[j];
|
||||
res[j] = -res[nn - j];
|
||||
res[nn - j] = -tmp;
|
||||
}
|
||||
res[m] = -res[m];
|
||||
return;
|
||||
}
|
||||
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||
if (((vp - binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||
res[j] = -res[j];
|
||||
}
|
||||
return;
|
||||
}
|
||||
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||
if (((vp + binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||
int64_t tmp = res[j];
|
||||
res[j] = res[nn - j];
|
||||
res[nn - j] = tmp;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// otherwise we will follow the orbit cycles,
|
||||
// starting from binval and -binval in parallel
|
||||
uint64_t j_start = binval;
|
||||
uint64_t nb_modif = 0;
|
||||
while (nb_modif < orb_size) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
int64_t tmp2 = res[nn - j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp1a = res[new_j_n];
|
||||
int64_t tmp2a = res[nn - new_j_n];
|
||||
if (new_j < nn) {
|
||||
res[new_j_n] = tmp1;
|
||||
res[nn - new_j_n] = tmp2;
|
||||
} else {
|
||||
res[new_j_n] = -tmp1;
|
||||
res[nn - new_j_n] = -tmp2;
|
||||
}
|
||||
tmp1 = tmp1a;
|
||||
tmp2 = tmp2a;
|
||||
// move to the new location, and store the number of items modified
|
||||
nb_modif += 2;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do *5, because 5 is a generator.
|
||||
j_start = (5 * j_start) & _mn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
const uint64_t m = nn >> 1;
|
||||
// reduce p mod 2n
|
||||
p &= _2mn;
|
||||
// uint64_t vp = p & _2mn;
|
||||
/// uint64_t target_modifs = m >> 1;
|
||||
// we proceed by increasing binary valuation
|
||||
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||
// At the beginning of this loop we have:
|
||||
// vp = binval * p mod 2n
|
||||
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||
|
||||
// first, handle the orders 1 and 2.
|
||||
// if p*binval == binval % 2n: we're done!
|
||||
if (vp == binval) return;
|
||||
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||
if (((vp + binval) & _2mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += binval) {
|
||||
double tmp = res[j];
|
||||
res[j] = -res[nn - j];
|
||||
res[nn - j] = -tmp;
|
||||
}
|
||||
res[m] = -res[m];
|
||||
return;
|
||||
}
|
||||
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||
if (((vp - binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||
res[j] = -res[j];
|
||||
}
|
||||
return;
|
||||
}
|
||||
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||
if (((vp + binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||
double tmp = res[j];
|
||||
res[j] = res[nn - j];
|
||||
res[nn - j] = tmp;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// otherwise we will follow the orbit cycles,
|
||||
// starting from binval and -binval in parallel
|
||||
uint64_t j_start = binval;
|
||||
uint64_t nb_modif = 0;
|
||||
while (nb_modif < orb_size) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
double tmp2 = res[nn - j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp1a = res[new_j_n];
|
||||
double tmp2a = res[nn - new_j_n];
|
||||
if (new_j < nn) {
|
||||
res[new_j_n] = tmp1;
|
||||
res[nn - new_j_n] = tmp2;
|
||||
} else {
|
||||
res[new_j_n] = -tmp1;
|
||||
res[nn - new_j_n] = -tmp2;
|
||||
}
|
||||
tmp1 = tmp1a;
|
||||
tmp2 = tmp2a;
|
||||
// move to the new location, and store the number of items modified
|
||||
nb_modif += 2;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do *5, because 5 is a generator.
|
||||
j_start = (5 * j_start) & _mn;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
#ifndef SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
#define SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
/** res = a + b */
|
||||
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
/** res = a - b */
|
||||
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
/** res = -a */
|
||||
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
/** res = a */
|
||||
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
/** res = 0 */
|
||||
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
|
||||
|
||||
/** res = a / m where m is a power of 2 */
|
||||
EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a);
|
||||
EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a);
|
||||
|
||||
/**
|
||||
* @param res = X^p *in mod X^nn +1
|
||||
* @param nn the ring dimension
|
||||
* @param p a power for the rotation -2nn <= p <= 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief res(X) = in(X^p)
|
||||
* @param nn the ring dimension
|
||||
* @param p is odd integer and must be between 0 < p < 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief res = (X^p-1).in
|
||||
* @param nn the ring dimension
|
||||
* @param p must be between -2nn <= p <= 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief Normalize input plus carry mod-2^k. The following
|
||||
* equality holds @c {in + carry_in == out + carry_out . 2^k}.
|
||||
*
|
||||
* @c in must be in [-2^62 .. 2^62]
|
||||
*
|
||||
* @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
|
||||
*
|
||||
* @c carry_in and @carry_out have at most 64+1-k bits.
|
||||
*
|
||||
* Null @c carry_in or @c carry_out are ignored.
|
||||
*
|
||||
* @param[in] nn the ring dimension
|
||||
* @param[in] base_k the base k
|
||||
* @param out output normalized znx
|
||||
* @param carry_out output carry znx
|
||||
* @param[in] in input znx
|
||||
* @param[in] carry_in input carry znx
|
||||
*/
|
||||
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||
const int64_t* carry_in);
|
||||
|
||||
#endif // SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
@@ -1,124 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "coeffs_arithmetic.h"
|
||||
|
||||
// res = a + b. dimension n must be a power of 2
|
||||
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = a[0] + b[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_add_epi64( //
|
||||
_mm_loadu_si128((__m128i*)a), //
|
||||
_mm_loadu_si128((__m128i*)b)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
const __m256i* bb = (__m256i*)b;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_add_epi64( //
|
||||
_mm256_loadu_si256(aa), //
|
||||
_mm256_loadu_si256(bb)));
|
||||
++rr;
|
||||
++aa;
|
||||
++bb;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
// res = a - b. dimension n must be a power of 2
|
||||
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = a[0] - b[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_sub_epi64( //
|
||||
_mm_loadu_si128((__m128i*)a), //
|
||||
_mm_loadu_si128((__m128i*)b)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
const __m256i* bb = (__m256i*)b;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_sub_epi64( //
|
||||
_mm256_loadu_si256(aa), //
|
||||
_mm256_loadu_si256(bb)));
|
||||
++rr;
|
||||
++aa;
|
||||
++bb;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = -a[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_sub_epi64( //
|
||||
_mm_set1_epi64x(0), //
|
||||
_mm_loadu_si128((__m128i*)a)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_sub_epi64( //
|
||||
_mm256_set1_epi64x(0), //
|
||||
_mm256_loadu_si256(aa)));
|
||||
++rr;
|
||||
++aa;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
|
||||
// TODO: see if there is a faster way of dividing by a power of 2?
|
||||
const double invm = 1. / m;
|
||||
if (n < 8) {
|
||||
switch (n) {
|
||||
case 1:
|
||||
*res = *a * invm;
|
||||
break;
|
||||
case 2:
|
||||
_mm_storeu_pd(res, //
|
||||
_mm_mul_pd(_mm_loadu_pd(a), //
|
||||
_mm_set1_pd(invm)));
|
||||
break;
|
||||
case 4:
|
||||
_mm256_storeu_pd(res, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(a), //
|
||||
_mm256_set1_pd(invm)));
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // non-power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
const __m256d invm256 = _mm256_set1_pd(invm);
|
||||
double* rr = res;
|
||||
const double* aa = a;
|
||||
const double* const aaend = a + n;
|
||||
do {
|
||||
_mm256_storeu_pd(rr, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(aa), //
|
||||
invm256));
|
||||
_mm256_storeu_pd(rr + 4, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(aa + 4), //
|
||||
invm256));
|
||||
rr += 8;
|
||||
aa += 8;
|
||||
} while (aa < aaend);
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
#include "commons.h"
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); }
|
||||
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); }
|
||||
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); }
|
||||
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); }
|
||||
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); }
|
||||
|
||||
#ifdef _WIN32
|
||||
#define __always_inline inline __attribute((always_inline))
|
||||
#endif
|
||||
|
||||
void internal_accurate_sincos(double* rcos, double* rsin, double x) {
|
||||
double _4_x_over_pi = 4 * x / M_PI;
|
||||
int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7;
|
||||
double frac_part = _4_x_over_pi - (double)(int_part);
|
||||
double frac_x = M_PI * frac_part / 4.;
|
||||
// compute the taylor series
|
||||
double cosp = 1.;
|
||||
double sinp = 0.;
|
||||
double powx = 1.;
|
||||
int64_t nn = 0;
|
||||
while (fabs(powx) > 1e-20) {
|
||||
++nn;
|
||||
powx = powx * frac_x / (double)(nn); // x^n/n!
|
||||
switch (nn & 3) {
|
||||
case 0:
|
||||
cosp += powx;
|
||||
break;
|
||||
case 1:
|
||||
sinp += powx;
|
||||
break;
|
||||
case 2:
|
||||
cosp -= powx;
|
||||
break;
|
||||
case 3:
|
||||
sinp -= powx;
|
||||
break;
|
||||
default:
|
||||
abort(); // impossible
|
||||
}
|
||||
}
|
||||
// final multiplication
|
||||
switch (int_part) {
|
||||
case 0:
|
||||
*rcos = cosp;
|
||||
*rsin = sinp;
|
||||
break;
|
||||
case 1:
|
||||
*rcos = M_SQRT1_2 * (cosp - sinp);
|
||||
*rsin = M_SQRT1_2 * (cosp + sinp);
|
||||
break;
|
||||
case 2:
|
||||
*rcos = -sinp;
|
||||
*rsin = cosp;
|
||||
break;
|
||||
case 3:
|
||||
*rcos = -M_SQRT1_2 * (cosp + sinp);
|
||||
*rsin = M_SQRT1_2 * (cosp - sinp);
|
||||
break;
|
||||
case 4:
|
||||
*rcos = -cosp;
|
||||
*rsin = -sinp;
|
||||
break;
|
||||
case 5:
|
||||
*rcos = -M_SQRT1_2 * (cosp - sinp);
|
||||
*rsin = -M_SQRT1_2 * (cosp + sinp);
|
||||
break;
|
||||
case 6:
|
||||
*rcos = sinp;
|
||||
*rsin = -cosp;
|
||||
break;
|
||||
case 7:
|
||||
*rcos = M_SQRT1_2 * (cosp + sinp);
|
||||
*rsin = -M_SQRT1_2 * (cosp - sinp);
|
||||
break;
|
||||
default:
|
||||
abort(); // impossible
|
||||
}
|
||||
if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) {
|
||||
printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x));
|
||||
printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x));
|
||||
printf("fracx = %.17lf\n", frac_x);
|
||||
printf("cosp = %.17lf\n", cosp);
|
||||
printf("sinp = %.17lf\n", sinp);
|
||||
printf("nn = %d\n", (int)(nn));
|
||||
}
|
||||
}
|
||||
|
||||
double internal_accurate_cos(double x) {
|
||||
double rcos, rsin;
|
||||
internal_accurate_sincos(&rcos, &rsin, x);
|
||||
return rcos;
|
||||
}
|
||||
double internal_accurate_sin(double x) {
|
||||
double rcos, rsin;
|
||||
internal_accurate_sincos(&rcos, &rsin, x);
|
||||
return rsin;
|
||||
}
|
||||
|
||||
EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); }
|
||||
|
||||
EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; }
|
||||
|
||||
EXPORT void spqlios_free(void* addr) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, we deallocated with spqlios_debug_free()
|
||||
spqlios_debug_free(addr);
|
||||
#else
|
||||
// in release mode, the function will free aligned memory
|
||||
#ifdef _WIN32
|
||||
_aligned_free(addr);
|
||||
#else
|
||||
free(addr);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT void* spqlios_alloc(uint64_t size) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, the function will not necessarily have any particular alignment
|
||||
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||
return spqlios_debug_alloc(size);
|
||||
#else
|
||||
// in release mode, the function will return 64-bytes aligned memory
|
||||
#ifdef _WIN32
|
||||
void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64);
|
||||
#else
|
||||
void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64)));
|
||||
#endif
|
||||
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||
return reps;
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, the function will not necessarily have any particular alignment
|
||||
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||
return spqlios_debug_alloc(size);
|
||||
#else
|
||||
// in release mode, the function will return aligned memory
|
||||
#ifdef _WIN32
|
||||
void* reps = _aligned_malloc(size, align);
|
||||
#else
|
||||
void* reps = aligned_alloc(align, size);
|
||||
#endif
|
||||
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||
return reps;
|
||||
#endif
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
#ifndef SPQLIOS_COMMONS_H
|
||||
#define SPQLIOS_COMMONS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#define EXPORT extern "C"
|
||||
#define EXPORT_DECL extern "C"
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#define EXPORT
|
||||
#define EXPORT_DECL extern
|
||||
#define nullptr 0x0;
|
||||
#endif
|
||||
|
||||
#define UNDEFINED() \
|
||||
{ \
|
||||
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_IMPLEMENTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define FATAL_ERROR(MESSAGE) \
|
||||
{ \
|
||||
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||
abort(); \
|
||||
}
|
||||
|
||||
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m);
|
||||
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m);
|
||||
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n);
|
||||
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n);
|
||||
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n);
|
||||
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a);
|
||||
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a);
|
||||
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n);
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n);
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n);
|
||||
EXPORT void NOT_IMPLEMENTED_v_dp(double* a);
|
||||
EXPORT void NOT_IMPLEMENTED_v_vp(void* p);
|
||||
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c);
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b);
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o);
|
||||
|
||||
// windows
|
||||
|
||||
#if defined(_WIN32) || defined(__APPLE__)
|
||||
#define __always_inline inline __attribute((always_inline))
|
||||
#endif
|
||||
|
||||
EXPORT void spqlios_free(void* address);
|
||||
|
||||
EXPORT void* spqlios_alloc(uint64_t size);
|
||||
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size);
|
||||
|
||||
#define USE_LIBM_SIN_COS
|
||||
#ifndef USE_LIBM_SIN_COS
|
||||
// if at some point, we want to remove the libm dependency, we can
|
||||
// consider this:
|
||||
EXPORT double internal_accurate_cos(double x);
|
||||
EXPORT double internal_accurate_sin(double x);
|
||||
EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x);
|
||||
#define m_accurate_cos internal_accurate_cos
|
||||
#define m_accurate_sin internal_accurate_sin
|
||||
#else
|
||||
// let's use libm sin and cos
|
||||
#define m_accurate_cos cos
|
||||
#define m_accurate_sin sin
|
||||
#endif
|
||||
|
||||
#endif // SPQLIOS_COMMONS_H
|
||||
@@ -1,55 +0,0 @@
|
||||
#include "commons_private.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "commons.h"
|
||||
|
||||
EXPORT void* spqlios_error(const char* error) {
|
||||
fputs(error, stderr);
|
||||
abort();
|
||||
return nullptr;
|
||||
}
|
||||
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) {
|
||||
if (!ptr2) {
|
||||
free(ptr);
|
||||
}
|
||||
return ptr2;
|
||||
}
|
||||
|
||||
EXPORT uint32_t log2m(uint32_t m) {
|
||||
uint32_t a = m - 1;
|
||||
if (m & a) FATAL_ERROR("m must be a power of two");
|
||||
a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u);
|
||||
a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u);
|
||||
a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu);
|
||||
a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu);
|
||||
return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu);
|
||||
}
|
||||
|
||||
EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; }
|
||||
|
||||
uint32_t revbits(uint32_t nbits, uint32_t value) {
|
||||
uint32_t res = 0;
|
||||
for (uint32_t i = 0; i < nbits; ++i) {
|
||||
res = (res << 1) + (value & 1);
|
||||
value >>= 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||
double fracrevbits(uint32_t i) {
|
||||
if (i == 0) return 0;
|
||||
if (i == 1) return 0.5;
|
||||
if (i % 2 == 0)
|
||||
return fracrevbits(i / 2) / 2.;
|
||||
else
|
||||
return fracrevbits((i - 1) / 2) / 2. + 0.5;
|
||||
}
|
||||
|
||||
uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); }
|
||||
|
||||
uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); }
|
||||
@@ -1,72 +0,0 @@
|
||||
#ifndef SPQLIOS_COMMONS_PRIVATE_H
|
||||
#define SPQLIOS_COMMONS_PRIVATE_H
|
||||
|
||||
#include "commons.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#else
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#define nullptr 0x0;
|
||||
#endif
|
||||
|
||||
/** @brief log2 of a power of two (UB if m is not a power of two) */
|
||||
EXPORT uint32_t log2m(uint32_t m);
|
||||
|
||||
/** @brief checks if the doublevalue is a power of two */
|
||||
EXPORT uint64_t is_not_pow2_double(void* doublevalue);
|
||||
|
||||
#define UNDEFINED() \
|
||||
{ \
|
||||
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_IMPLEMENTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_SUPPORTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT SUPPORTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define FATAL_ERROR(MESSAGE) \
|
||||
{ \
|
||||
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||
abort(); \
|
||||
}
|
||||
|
||||
#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])
|
||||
|
||||
/** @brief reports the error and returns nullptr */
|
||||
EXPORT void* spqlios_error(const char* error);
|
||||
/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */
|
||||
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2);
|
||||
|
||||
#ifdef __x86_64__
|
||||
#define CPU_SUPPORTS __builtin_cpu_supports
|
||||
#else
|
||||
// TODO for now, we do not have any optimization for non x86 targets
|
||||
#define CPU_SUPPORTS(xxxx) 0
|
||||
#endif
|
||||
|
||||
/** @brief returns the n bits of value in reversed order */
|
||||
EXPORT uint32_t revbits(uint32_t nbits, uint32_t value);
|
||||
|
||||
/**
|
||||
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||
EXPORT double fracrevbits(uint32_t i);
|
||||
|
||||
/** @brief smallest multiple of 64 higher or equal to size */
|
||||
EXPORT uint64_t ceilto64b(uint64_t size);
|
||||
|
||||
/** @brief smallest multiple of 32 higher or equal to size */
|
||||
EXPORT uint64_t ceilto32b(uint64_t size);
|
||||
|
||||
#endif // SPQLIOS_COMMONS_PRIVATE_H
|
||||
@@ -1,22 +0,0 @@
|
||||
In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`.
|
||||
One complex is represented by two consecutive doubles `(real,imag)`
|
||||
Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1
|
||||
corresponds to the complex polynomial of half degree `M=N/2`:
|
||||
`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i`
|
||||
|
||||
For a complex polynomial A(X) sum c_i X^i of degree M-1
|
||||
or a real polynomial sum a_i X^i of degree N
|
||||
|
||||
coefficient space:
|
||||
a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1}
|
||||
or equivalently
|
||||
Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1})
|
||||
|
||||
eval space:
|
||||
c(omega_{0}),...,c(omega_{M-1})
|
||||
|
||||
where
|
||||
omega_j = omega^{1+rev_{2N}(j)}
|
||||
and omega = exp(i.pi/N)
|
||||
|
||||
rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order.
|
||||
@@ -1,80 +0,0 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
|
||||
void cplx_set(CPLX r, const CPLX a) {
|
||||
r[0] = a[0];
|
||||
r[1] = a[1];
|
||||
}
|
||||
void cplx_neg(CPLX r, const CPLX a) {
|
||||
r[0] = -a[0];
|
||||
r[1] = -a[1];
|
||||
}
|
||||
void cplx_add(CPLX r, const CPLX a, const CPLX b) {
|
||||
r[0] = a[0] + b[0];
|
||||
r[1] = a[1] + b[1];
|
||||
}
|
||||
void cplx_sub(CPLX r, const CPLX a, const CPLX b) {
|
||||
r[0] = a[0] - b[0];
|
||||
r[1] = a[1] - b[1];
|
||||
}
|
||||
void cplx_mul(CPLX r, const CPLX a, const CPLX b) {
|
||||
double re = a[0] * b[0] - a[1] * b[1];
|
||||
r[1] = a[0] * b[1] + a[1] * b[0];
|
||||
r[0] = re;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
|
||||
* Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||
* Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||
* @param powom y represented as (yre,yim)
|
||||
*/
|
||||
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d1 = data + h;
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
CPLX diff;
|
||||
cplx_sub(diff, d0[i], d1[i]);
|
||||
cplx_add(d0[i], d0[i], d1[i]);
|
||||
cplx_mul(d1[i], diff, powom);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Do two layers of itwiddle (i.e. split).
|
||||
* Input/output: d0,d1,d2,d3 of length h
|
||||
* Algo:
|
||||
* itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0])
|
||||
* itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1])
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 4h complex coefficients interleaved and 256b aligned
|
||||
* @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2
|
||||
*/
|
||||
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d2 = data + 2 * h;
|
||||
const CPLX* om0 = powom;
|
||||
CPLX iom0;
|
||||
iom0[0] = powom[0][1];
|
||||
iom0[1] = -powom[0][0];
|
||||
const CPLX* om1 = powom + 1;
|
||||
cplx_split_fft_ref(h, d0, *om0);
|
||||
cplx_split_fft_ref(h, d2, iom0);
|
||||
cplx_split_fft_ref(2 * h, d0, *om1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Input: Q(y),Q(-y)
|
||||
* Output: P_0(z),P_1(z)
|
||||
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||
*/
|
||||
void split_fft_last_ref(CPLX* data, const CPLX powom) {
|
||||
CPLX diff;
|
||||
cplx_sub(diff, data[0], data[1]);
|
||||
cplx_add(data[0], data[0], data[1]);
|
||||
cplx_mul(data[1], diff, powom);
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
const uint32_t m = precomp->m;
|
||||
const int32_t* inre = x;
|
||||
const int32_t* inim = x + m;
|
||||
CPLX* out = r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
out[i][0] = (double)inre[i];
|
||||
out[i][1] = (double)inim[i];
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
static const double _2p32 = 1. / (INT64_C(1) << 32);
|
||||
const uint32_t m = precomp->m;
|
||||
const int32_t* inre = x;
|
||||
const int32_t* inim = x + m;
|
||||
CPLX* out = r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
out[i][0] = ((double)inre[i]) * _2p32;
|
||||
out[i][1] = ((double)inim[i]) * _2p32;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||
static const double _2p32 = (INT64_C(1) << 32);
|
||||
const uint32_t m = precomp->m;
|
||||
double factor = _2p32 / precomp->divisor;
|
||||
int32_t* outre = r;
|
||||
int32_t* outim = r + m;
|
||||
const CPLX* in = x;
|
||||
// Note: this formula will only work if abs(in) < 2^32
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor));
|
||||
outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor));
|
||||
}
|
||||
}
|
||||
|
||||
void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) {
|
||||
res->m = m;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_from_znx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_from_znx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_from_znx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) {
|
||||
CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m));
|
||||
}
|
||||
|
||||
void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) {
|
||||
res->m = m;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_from_tnx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_from_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_from_tnx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) {
|
||||
CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m));
|
||||
}
|
||||
|
||||
void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) {
|
||||
if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2");
|
||||
if (m & (m - 1)) return spqlios_error("m must be a power of 2");
|
||||
if (log2overhead > 52) return spqlios_error("log2overhead is too large");
|
||||
res->m = m;
|
||||
res->divisor = divisor;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (log2overhead <= 18) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_to_tnx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) {
|
||||
CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the znx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||
// not checking for log2bound which is not relevant here
|
||||
static CPLX_FROM_ZNX32_PRECOMP precomp[32];
|
||||
CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->function) {
|
||||
if (!init_cplx_from_znx32_precomp(p, m)) abort();
|
||||
}
|
||||
p->function(p, r, x);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||
static CPLX_FROM_TNX32_PRECOMP precomp[32];
|
||||
CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->function) {
|
||||
if (!init_cplx_from_tnx32_precomp(p, m)) abort();
|
||||
}
|
||||
p->function(p, r, x);
|
||||
}
|
||||
/**
|
||||
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) {
|
||||
struct LAST_CPLX_TO_TNX32_PRECOMP {
|
||||
CPLX_TO_TNX32_PRECOMP p;
|
||||
double last_divisor;
|
||||
double last_log2over;
|
||||
};
|
||||
static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32];
|
||||
struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) {
|
||||
memset(p, 0, sizeof(*p));
|
||||
if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort();
|
||||
p->last_divisor = divisor;
|
||||
p->last_log2over = log2overhead;
|
||||
}
|
||||
p->p.function(&p->p, r, x);
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef int32_t I8MEM[8];
|
||||
typedef double D4MEM[4];
|
||||
|
||||
__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) {
|
||||
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||
const I8MEM* inre = (I8MEM*)(x);
|
||||
const I8MEM* inim = (I8MEM*)(x + m);
|
||||
D4MEM* out = (D4MEM*)r;
|
||||
const uint64_t ms8 = m / 8;
|
||||
for (uint32_t i = 0; i < ms8; ++i) {
|
||||
__m256i rea = _mm256_loadu_si256((__m256i*)inre[0]);
|
||||
__m256i ima = _mm256_loadu_si256((__m256i*)inim[0]);
|
||||
rea = _mm256_add_epi32(rea, S);
|
||||
ima = _mm256_add_epi32(ima, S);
|
||||
__m256i tmpa = _mm256_unpacklo_epi32(rea, ima);
|
||||
__m256i tmpc = _mm256_unpackhi_epi32(rea, ima);
|
||||
__m256i cpla = _mm256_permute2x128_si256(tmpa, tmpc, 0x20);
|
||||
__m256i cplc = _mm256_permute2x128_si256(tmpa, tmpc, 0x31);
|
||||
tmpa = _mm256_unpacklo_epi32(cpla, C);
|
||||
__m256i tmpb = _mm256_unpackhi_epi32(cpla, C);
|
||||
tmpc = _mm256_unpacklo_epi32(cplc, C);
|
||||
__m256i tmpd = _mm256_unpackhi_epi32(cplc, C);
|
||||
cpla = _mm256_permute2x128_si256(tmpa, tmpb, 0x20);
|
||||
__m256i cplb = _mm256_permute2x128_si256(tmpa, tmpb, 0x31);
|
||||
cplc = _mm256_permute2x128_si256(tmpc, tmpd, 0x20);
|
||||
__m256i cpld = _mm256_permute2x128_si256(tmpc, tmpd, 0x31);
|
||||
__m256d dcpla = _mm256_sub_pd(_mm256_castsi256_pd(cpla), R);
|
||||
__m256d dcplb = _mm256_sub_pd(_mm256_castsi256_pd(cplb), R);
|
||||
__m256d dcplc = _mm256_sub_pd(_mm256_castsi256_pd(cplc), R);
|
||||
__m256d dcpld = _mm256_sub_pd(_mm256_castsi256_pd(cpld), R);
|
||||
_mm256_storeu_pd(out[0], dcpla);
|
||||
_mm256_storeu_pd(out[1], dcplb);
|
||||
_mm256_storeu_pd(out[2], dcplc);
|
||||
_mm256_storeu_pd(out[3], dcpld);
|
||||
inre += 1;
|
||||
inim += 1;
|
||||
out += 4;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
// note: the hex code of 2^31 + 2^52 is 0x4330000080000000
|
||||
const __m256i C = _mm256_set1_epi32(0x43300000);
|
||||
const __m256d R = _mm256_set1_pd((INT64_C(1) << 31) + (INT64_C(1) << 52));
|
||||
// double XX = INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52);
|
||||
// printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||
// abort();
|
||||
const uint64_t m = precomp->m;
|
||||
cplx_from_any_fma(m, r, x, C, R);
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
// note: the hex code of 2^-1 + 2^30 is 0x4130000080000000
|
||||
const __m256i C = _mm256_set1_epi32(0x41300000);
|
||||
const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20));
|
||||
// double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32);
|
||||
// printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||
// abort();
|
||||
const uint64_t m = precomp->m;
|
||||
cplx_from_any_fma(m, r, x, C, R);
|
||||
}
|
||||
|
||||
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||
const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor);
|
||||
const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||
// const __m256i IDX = _mm256_set_epi32(0,4,1,5,2,6,3,7);
|
||||
const __m256i IDX = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
|
||||
const uint64_t m = precomp->m;
|
||||
const uint64_t ms8 = m / 8;
|
||||
I8MEM* outre = (I8MEM*)r;
|
||||
I8MEM* outim = (I8MEM*)(r + m);
|
||||
const D4MEM* in = x;
|
||||
// Note: this formula will only work if abs(in) < 2^32
|
||||
for (uint32_t i = 0; i < ms8; ++i) {
|
||||
__m256d cpla = _mm256_loadu_pd(in[0]);
|
||||
__m256d cplb = _mm256_loadu_pd(in[1]);
|
||||
__m256d cplc = _mm256_loadu_pd(in[2]);
|
||||
__m256d cpld = _mm256_loadu_pd(in[3]);
|
||||
__m256i icpla = _mm256_castpd_si256(_mm256_add_pd(cpla, R));
|
||||
__m256i icplb = _mm256_castpd_si256(_mm256_add_pd(cplb, R));
|
||||
__m256i icplc = _mm256_castpd_si256(_mm256_add_pd(cplc, R));
|
||||
__m256i icpld = _mm256_castpd_si256(_mm256_add_pd(cpld, R));
|
||||
icpla = _mm256_or_si256(_mm256_and_si256(icpla, MASK), _mm256_slli_epi64(icplb, 32));
|
||||
icplc = _mm256_or_si256(_mm256_and_si256(icplc, MASK), _mm256_slli_epi64(icpld, 32));
|
||||
icpla = _mm256_xor_si256(icpla, S);
|
||||
icplc = _mm256_xor_si256(icplc, S);
|
||||
__m256i re = _mm256_unpacklo_epi64(icpla, icplc);
|
||||
__m256i im = _mm256_unpackhi_epi64(icpla, icplc);
|
||||
re = _mm256_permutevar8x32_epi32(re, IDX);
|
||||
im = _mm256_permutevar8x32_epi32(im, IDX);
|
||||
_mm256_storeu_si256((__m256i*)outre[0], re);
|
||||
_mm256_storeu_si256((__m256i*)outim[0], im);
|
||||
outre += 1;
|
||||
outim += 1;
|
||||
in += 4;
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
tables->function(tables, r, a, b);
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
tables->function(tables, r, a, b);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
UNDEFINED(); // not defined for non x86 targets
|
||||
}
|
||||
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
|
||||
const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); }
|
||||
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT
|
||||
void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){
|
||||
UNDEFINED()} EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b,
|
||||
const void* om){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om){UNDEFINED()}
|
||||
|
||||
// DEPRECATED?
|
||||
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||
UNDEFINED()
|
||||
}
|
||||
|
||||
// executors
|
||||
// EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
|
||||
// itables->function(itables, data);
|
||||
//}
|
||||
// EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
|
||||
@@ -1,221 +0,0 @@
|
||||
#ifndef SPQLIOS_CPLX_FFT_H
|
||||
#define SPQLIOS_CPLX_FFT_H
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP;
|
||||
typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP;
|
||||
typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP;
|
||||
typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP;
|
||||
typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP;
|
||||
typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP;
|
||||
typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP;
|
||||
typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP;
|
||||
typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP;
|
||||
typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP;
|
||||
typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP;
|
||||
|
||||
/**
|
||||
* @brief precomputes fft tables.
|
||||
* The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn.
|
||||
* The resulting pointer is to be passed as "tables" argument to any call to the fft function.
|
||||
* The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to
|
||||
* the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure
|
||||
* that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft
|
||||
* table must be deleted by delete_fft_precomp after its last usage.
|
||||
*/
|
||||
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers);
|
||||
|
||||
/**
|
||||
* @brief gets the address of a fft buffer allocated during new_fft_precomp.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and does not need to be released afterwards.
|
||||
*/
|
||||
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index);
|
||||
|
||||
/**
|
||||
* @brief allocates a new fft buffer.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||
*/
|
||||
EXPORT void* new_cplx_fft_buffer(uint32_t m);
|
||||
|
||||
/**
|
||||
* @brief allocates a new fft buffer.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||
*/
|
||||
EXPORT void delete_cplx_fft_buffer(void* buffer);
|
||||
|
||||
/**
|
||||
* @brief deallocates a fft table and all its built-in buffers.
|
||||
*/
|
||||
#define delete_cplx_fft_precomp free
|
||||
|
||||
/**
|
||||
* @brief computes a direct fft in-place over data.
|
||||
*/
|
||||
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||
|
||||
EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers);
|
||||
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index);
|
||||
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data);
|
||||
#define delete_cplx_ifft_precomp free
|
||||
|
||||
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m);
|
||||
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
#define delete_cplx_fftvec_mul_precomp free
|
||||
|
||||
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m);
|
||||
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
#define delete_cplx_fftvec_addmul_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a conversion from ZnX to the cplx layout.
|
||||
* All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use
|
||||
* this function on a larger coefficient is undefined behaviour. The resulting precomputed data must
|
||||
* be freed with `new_cplx_from_znx32_precomp`
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||
* int32 coefficients in natural order modulo X^n+1
|
||||
* @param log2bound bound on the input coefficients. Must be between 0 and 32
|
||||
*/
|
||||
EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m);
|
||||
/**
|
||||
* @brief converts from ZnX to the cplx layout.
|
||||
* @param tables precomputed data obtained by new_cplx_from_znx32_precomp.
|
||||
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||
* @param x input array of n bounded integer coefficients mod X^n+1
|
||||
*/
|
||||
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||
/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */
|
||||
#define delete_cplx_from_znx32_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a conversion from TnX to the cplx layout.
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||
* torus32 coefficients. The resulting precomputed data must
|
||||
* be freed with `delete_cplx_from_tnx32_precomp`
|
||||
*/
|
||||
EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m);
|
||||
/**
|
||||
* @brief converts from TnX to the cplx layout.
|
||||
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||
* @param x input array of n torus32 coefficients mod X^n+1
|
||||
*/
|
||||
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||
/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */
|
||||
#define delete_cplx_from_tnx32_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a rescale and conversion from the cplx layout to TnX.
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
|
||||
* torus32 coefficients.
|
||||
* @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1.
|
||||
* Remember that the output of an iFFT must be divided by m.
|
||||
* @param log2overhead all inputs absolute values must be within divisor.2^log2overhead.
|
||||
* For any inputs outside of these bounds, the conversion is undefined behaviour.
|
||||
* The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18.
|
||||
*/
|
||||
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead);
|
||||
/**
|
||||
* @brief rescale, converts and reduce mod 1 from cplx layout to torus32.
|
||||
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||
* @param r resulting array of n torus32 coefficients mod X^n+1
|
||||
* @param x input array of m cplx coefficients mod X^m-i
|
||||
*/
|
||||
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a);
|
||||
#define delete_cplx_to_tnx32_precomp free
|
||||
|
||||
EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor);
|
||||
EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||
#define delete_cplx_to_znx32_simple free
|
||||
|
||||
EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m);
|
||||
EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||
#define delete_cplx_from_rnx64_simple free
|
||||
|
||||
EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor);
|
||||
EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
#define delete_cplx_round_to_rnx64_simple free
|
||||
|
||||
EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound);
|
||||
EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
#define delete_cplx_round_to_rnx64_simple free
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the fft function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically.
|
||||
* It is advised to do one dry-run per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fft_simple(uint32_t m, void* data);
|
||||
/**
|
||||
* @brief Simpler API for the ifft function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread
|
||||
* environment */
|
||||
EXPORT void cplx_ifft_simple(uint32_t m, void* data);
|
||||
/**
|
||||
* @brief Simpler API for the fftvec multiplication function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||
/**
|
||||
* @brief Simpler API for the fftvec addmul function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||
/**
|
||||
* @brief Simpler API for the znx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||
/**
|
||||
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||
/**
|
||||
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides and round from cplx to znx32 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts from rnx64 to cplx (simple API)
|
||||
* The bound on the output is assumed to be within ]2^-31,2^31[.
|
||||
* Any coefficient that would fall outside this range is undefined behaviour.
|
||||
* @param m the complex dimension
|
||||
* @param r the result: must be an array of m complex numbers. r must be distinct from x
|
||||
* @param x the input: must be an array of 2m doubles.
|
||||
*/
|
||||
EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides from cplx to rnx64 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides and round to integer from cplx to rnx32 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm.
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x);
|
||||
|
||||
#endif // SPQLIOS_CPLX_FFT_H
|
||||
@@ -1,156 +0,0 @@
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
.globl cplx_fft16_avx_fma
|
||||
cplx_fft16_avx_fma:
|
||||
vmovupd (%rdi),%ymm8
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd (%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm1,%ymm6
|
||||
vmulpd %ymm7,%ymm1,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm12
|
||||
vsubpd %ymm5,%ymm9,%ymm13
|
||||
vsubpd %ymm6,%ymm10,%ymm14
|
||||
vsubpd %ymm7,%ymm11,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
|
||||
.second_pass:
|
||||
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm0,%ymm6
|
||||
vmulpd %ymm7,%ymm0,%ymm7
|
||||
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm10
|
||||
vsubpd %ymm5,%ymm9,%ymm11
|
||||
vaddpd %ymm6,%ymm12,%ymm14
|
||||
vaddpd %ymm7,%ymm13,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vsubpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm13,%ymm13
|
||||
|
||||
.third_pass:
|
||||
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm9
|
||||
vaddpd %ymm5,%ymm10,%ymm11
|
||||
vsubpd %ymm6,%ymm12,%ymm13
|
||||
vaddpd %ymm7,%ymm14,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vsubpd %ymm5,%ymm10,%ymm10
|
||||
vaddpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm14,%ymm14
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm0,%ymm13
|
||||
vmulpd %ymm14,%ymm3,%ymm14
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vaddpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vaddpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vsubpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vsubpd %ymm15,%ymm11,%ymm11
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
.size cplx_fft16_avx_fma, .-cplx_fft16_avx_fma
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
@@ -1,190 +0,0 @@
|
||||
.text
|
||||
.p2align 4
|
||||
.globl cplx_fft16_avx_fma
|
||||
.def cplx_fft16_avx_fma; .scl 2; .type 32; .endef
|
||||
cplx_fft16_avx_fma:
|
||||
|
||||
pushq %rdi
|
||||
pushq %rsi
|
||||
movq %rcx,%rdi
|
||||
movq %rdx,%rsi
|
||||
subq $0x100,%rsp
|
||||
movdqu %xmm6,(%rsp)
|
||||
movdqu %xmm7,0x10(%rsp)
|
||||
movdqu %xmm8,0x20(%rsp)
|
||||
movdqu %xmm9,0x30(%rsp)
|
||||
movdqu %xmm10,0x40(%rsp)
|
||||
movdqu %xmm11,0x50(%rsp)
|
||||
movdqu %xmm12,0x60(%rsp)
|
||||
movdqu %xmm13,0x70(%rsp)
|
||||
movdqu %xmm14,0x80(%rsp)
|
||||
movdqu %xmm15,0x90(%rsp)
|
||||
callq cplx_fft16_avx_fma_amd64
|
||||
movdqu (%rsp),%xmm6
|
||||
movdqu 0x10(%rsp),%xmm7
|
||||
movdqu 0x20(%rsp),%xmm8
|
||||
movdqu 0x30(%rsp),%xmm9
|
||||
movdqu 0x40(%rsp),%xmm10
|
||||
movdqu 0x50(%rsp),%xmm11
|
||||
movdqu 0x60(%rsp),%xmm12
|
||||
movdqu 0x70(%rsp),%xmm13
|
||||
movdqu 0x80(%rsp),%xmm14
|
||||
movdqu 0x90(%rsp),%xmm15
|
||||
addq $0x100,%rsp
|
||||
popq %rsi
|
||||
popq %rdi
|
||||
retq
|
||||
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
cplx_fft16_avx_fma_amd64:
|
||||
vmovupd (%rdi),%ymm8
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd (%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm1,%ymm6
|
||||
vmulpd %ymm7,%ymm1,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm12
|
||||
vsubpd %ymm5,%ymm9,%ymm13
|
||||
vsubpd %ymm6,%ymm10,%ymm14
|
||||
vsubpd %ymm7,%ymm11,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
|
||||
.second_pass:
|
||||
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm0,%ymm6
|
||||
vmulpd %ymm7,%ymm0,%ymm7
|
||||
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm10
|
||||
vsubpd %ymm5,%ymm9,%ymm11
|
||||
vaddpd %ymm6,%ymm12,%ymm14
|
||||
vaddpd %ymm7,%ymm13,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vsubpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm13,%ymm13
|
||||
|
||||
.third_pass:
|
||||
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm9
|
||||
vaddpd %ymm5,%ymm10,%ymm11
|
||||
vsubpd %ymm6,%ymm12,%ymm13
|
||||
vaddpd %ymm7,%ymm14,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vsubpd %ymm5,%ymm10,%ymm10
|
||||
vaddpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm14,%ymm14
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm0,%ymm13
|
||||
vmulpd %ymm14,%ymm3,%ymm14
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vaddpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vaddpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vsubpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vsubpd %ymm15,%ymm11,%ymm11
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
@@ -1,8 +0,0 @@
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
__always_inline void my_asserts() {
|
||||
STATIC_ASSERT(sizeof(FFT_FUNCTION) == 8);
|
||||
STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP) == 40);
|
||||
STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP) == 40);
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef double D4MEM[4];
|
||||
|
||||
/**
|
||||
* @brief complex fft via bfs strategy (for m between 2 and 8)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_fft_avx2_fma_bfs_2(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||
double* data = (double*)dat;
|
||||
int32_t _2nblock = m >> 1; // = h in ref code
|
||||
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||
while (_2nblock >= 2) {
|
||||
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, t2);
|
||||
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
_mm256_storeu_pd(ddmid[0], newb);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
dd += nblock;
|
||||
*omg += 1;
|
||||
} while (dd < finaldd);
|
||||
_2nblock >>= 1;
|
||||
}
|
||||
// last iteration when _2nblock == 1
|
||||
{
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||
const __m256d ab = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d bb = _mm256_permute4x64_pd(ab, 0b11101110);
|
||||
const __m256d bbbar = _mm256_permute4x64_pd(ab, 0b10111011);
|
||||
const __m256d t1 = _mm256_mul_pd(bbbar, omim);
|
||||
const __m256d t2 = _mm256_fmaddsub_pd(bb, omre, t1);
|
||||
const __m256d aa = _mm256_permute4x64_pd(ab, 0b01000100);
|
||||
const __m256d newab = _mm256_add_pd(aa, t2);
|
||||
_mm256_storeu_pd(dd[0], newab);
|
||||
dd += 1;
|
||||
*omg += 1;
|
||||
} while (dd < finaldd);
|
||||
}
|
||||
}
|
||||
|
||||
__always_inline void cplx_twiddle_fft_avx2(int32_t h, D4MEM* data, const void* omg) {
|
||||
const __m256d om = _mm256_loadu_pd(omg);
|
||||
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
D4MEM* d0 = data;
|
||||
D4MEM* const ddend = d0 + (h >> 1);
|
||||
D4MEM* d1 = ddend;
|
||||
do {
|
||||
const __m256d b = _mm256_loadu_pd(d1[0]);
|
||||
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||
const __m256d t1 = _mm256_mul_pd(barb, omim);
|
||||
const __m256d t2 = _mm256_fmaddsub_pd(b, omre, t1);
|
||||
const __m256d a = _mm256_loadu_pd(d0[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, t2);
|
||||
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||
_mm256_storeu_pd(d0[0], newa);
|
||||
_mm256_storeu_pd(d1[0], newb);
|
||||
d0 += 1;
|
||||
d1 += 1;
|
||||
} while (d0 < ddend);
|
||||
}
|
||||
|
||||
__always_inline void cplx_bitwiddle_fft_avx2(int32_t h, void* data, const void* powom) {
|
||||
const __m256d omx = _mm256_loadu_pd(powom);
|
||||
const __m256d oma = _mm256_permute2f128_pd(omx, omx, 0x00);
|
||||
const __m256d omb = _mm256_permute2f128_pd(omx, omx, 0x11);
|
||||
const __m256d omaim = _mm256_unpackhi_pd(oma, oma);
|
||||
const __m256d omare = _mm256_unpacklo_pd(oma, oma);
|
||||
const __m256d ombim = _mm256_unpackhi_pd(omb, omb);
|
||||
const __m256d ombre = _mm256_unpacklo_pd(omb, omb);
|
||||
D4MEM* d0 = (D4MEM*)data;
|
||||
D4MEM* const ddend = d0 + (h >> 1);
|
||||
D4MEM* d1 = ddend;
|
||||
D4MEM* d2 = d0 + h;
|
||||
D4MEM* d3 = d1 + h;
|
||||
__m256d reg0, reg1, reg2, reg3, tmp0, tmp1;
|
||||
do {
|
||||
reg0 = _mm256_loadu_pd(d0[0]);
|
||||
reg1 = _mm256_loadu_pd(d1[0]);
|
||||
reg2 = _mm256_loadu_pd(d2[0]);
|
||||
reg3 = _mm256_loadu_pd(d3[0]);
|
||||
tmp0 = _mm256_shuffle_pd(reg2, reg2, 5);
|
||||
tmp1 = _mm256_shuffle_pd(reg3, reg3, 5);
|
||||
tmp0 = _mm256_mul_pd(tmp0, omaim);
|
||||
tmp1 = _mm256_mul_pd(tmp1, omaim);
|
||||
tmp0 = _mm256_fmaddsub_pd(reg2, omare, tmp0);
|
||||
tmp1 = _mm256_fmaddsub_pd(reg3, omare, tmp1);
|
||||
reg2 = _mm256_sub_pd(reg0, tmp0);
|
||||
reg3 = _mm256_sub_pd(reg1, tmp1);
|
||||
reg0 = _mm256_add_pd(reg0, tmp0);
|
||||
reg1 = _mm256_add_pd(reg1, tmp1);
|
||||
//--------------------------------------
|
||||
tmp0 = _mm256_shuffle_pd(reg1, reg1, 5);
|
||||
tmp1 = _mm256_shuffle_pd(reg3, reg3, 5);
|
||||
tmp0 = _mm256_mul_pd(tmp0, ombim); //(r,i)
|
||||
tmp1 = _mm256_mul_pd(tmp1, ombre); //(-i,r)
|
||||
tmp0 = _mm256_fmaddsub_pd(reg1, ombre, tmp0);
|
||||
tmp1 = _mm256_fmsubadd_pd(reg3, ombim, tmp1);
|
||||
reg1 = _mm256_sub_pd(reg0, tmp0);
|
||||
reg3 = _mm256_add_pd(reg2, tmp1);
|
||||
reg0 = _mm256_add_pd(reg0, tmp0);
|
||||
reg2 = _mm256_sub_pd(reg2, tmp1);
|
||||
/////
|
||||
_mm256_storeu_pd(d0[0], reg0);
|
||||
_mm256_storeu_pd(d1[0], reg1);
|
||||
_mm256_storeu_pd(d2[0], reg2);
|
||||
_mm256_storeu_pd(d3[0], reg3);
|
||||
d0 += 1;
|
||||
d1 += 1;
|
||||
d2 += 1;
|
||||
d3 += 1;
|
||||
} while (d0 < ddend);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex fft via bfs strategy (for m >= 16)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_fft_avx2_fma_bfs_16(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||
double* data = (double*)dat;
|
||||
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||
uint32_t mm = m;
|
||||
uint32_t log2m = _mm_popcnt_u32(m - 1); // log2(m)
|
||||
if (log2m % 2 == 1) {
|
||||
uint32_t h = mm >> 1;
|
||||
cplx_twiddle_fft_avx2(h, dat, **omg);
|
||||
*omg += 1;
|
||||
mm >>= 1;
|
||||
}
|
||||
while (mm > 16) {
|
||||
uint32_t h = mm / 4;
|
||||
for (CPLX* d = (CPLX*)data; d < (CPLX*)finaldd; d += mm) {
|
||||
cplx_bitwiddle_fft_avx2(h, d, (CPLX*)*omg);
|
||||
*omg += 1;
|
||||
}
|
||||
mm = h;
|
||||
}
|
||||
{
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
cplx_fft16_avx_fma(dd, *omg);
|
||||
dd += 8;
|
||||
*omg += 4;
|
||||
} while (dd < finaldd);
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
/*
|
||||
int32_t _2nblock = m >> 1; // = h in ref code
|
||||
while (_2nblock >= 16) {
|
||||
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, t2);
|
||||
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
_mm256_storeu_pd(ddmid[0], newb);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
dd += nblock;
|
||||
*omg += 1;
|
||||
} while (dd < finaldd);
|
||||
_2nblock >>= 1;
|
||||
}
|
||||
// last iteration when _2nblock == 8
|
||||
{
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
cplx_fft16_avx_fma(dd, *omg);
|
||||
dd += 8;
|
||||
*omg += 4;
|
||||
} while (dd < finaldd);
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex fft via dfs recursion (for m >= 16)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_fft_avx2_fma_rec_16(D4MEM* dat, const D4MEM** omg, uint32_t m) {
|
||||
if (m <= 8) return cplx_fft_avx2_fma_bfs_2(dat, omg, m);
|
||||
if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(dat, omg, m);
|
||||
double* data = (double*)dat;
|
||||
int32_t _2nblock = m >> 1; // = h in ref code
|
||||
int32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
const __m256d om = _mm256_load_pd(*omg[0]);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d t1 = _mm256_mul_pd(b, omre);
|
||||
const __m256d barb = _mm256_shuffle_pd(b, b, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1);
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, t2);
|
||||
const __m256d newb = _mm256_sub_pd(a, t2);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
_mm256_storeu_pd(ddmid[0], newb);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
*omg += 1;
|
||||
cplx_fft_avx2_fma_rec_16(dat, omg, _2nblock);
|
||||
cplx_fft_avx2_fma_rec_16(ddend, omg, _2nblock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex fft via best strategy (for m>=1)
|
||||
* @param dat the data to run the algorithm on: m complex numbers
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* precomp, void* d) {
|
||||
const uint32_t m = precomp->m;
|
||||
const D4MEM* omg = (D4MEM*)precomp->powomegas;
|
||||
if (m <= 1) return;
|
||||
if (m <= 8) return cplx_fft_avx2_fma_bfs_2(d, &omg, m);
|
||||
if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(d, &omg, m);
|
||||
cplx_fft_avx2_fma_rec_16(d, &omg, m);
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef double D2MEM[2];
|
||||
typedef double D4MEM[4];
|
||||
typedef double D8MEM[8];
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
|
||||
const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const D8MEM* aa = (D8MEM*)a;
|
||||
const D8MEM* bb = (D8MEM*)b;
|
||||
D8MEM* rr = (D8MEM*)r;
|
||||
const D8MEM* const aend = aa + (m >> 2);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||
const __m512d bri% = _mm512_loadu_pd(bb[%]);
|
||||
const __m512d rri% = _mm512_loadu_pd(rr[%]);
|
||||
const __m512d bir% = _mm512_shuffle_pd(bri%,bri%, 0b01010101);
|
||||
const __m512d aii% = _mm512_shuffle_pd(ari%,ari%, 0b11111111);
|
||||
const __m512d pro% = _mm512_fmaddsub_pd(aii%,bir%,rri%);
|
||||
const __m512d arr% = _mm512_shuffle_pd(ari%,ari%, 0b00000000);
|
||||
const __m512d res% = _mm512_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm512_storeu_pd(rr[%],res%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
bb += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 2
|
||||
const __m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||
const __m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||
const __m512d bri0 = _mm512_loadu_pd(bb[0]);
|
||||
const __m512d bri1 = _mm512_loadu_pd(bb[1]);
|
||||
const __m512d rri0 = _mm512_loadu_pd(rr[0]);
|
||||
const __m512d rri1 = _mm512_loadu_pd(rr[1]);
|
||||
const __m512d bir0 = _mm512_shuffle_pd(bri0, bri0, 0b01010101);
|
||||
const __m512d bir1 = _mm512_shuffle_pd(bri1, bri1, 0b01010101);
|
||||
const __m512d aii0 = _mm512_shuffle_pd(ari0, ari0, 0b11111111);
|
||||
const __m512d aii1 = _mm512_shuffle_pd(ari1, ari1, 0b11111111);
|
||||
const __m512d pro0 = _mm512_fmaddsub_pd(aii0, bir0, rri0);
|
||||
const __m512d pro1 = _mm512_fmaddsub_pd(aii1, bir1, rri1);
|
||||
const __m512d arr0 = _mm512_shuffle_pd(ari0, ari0, 0b00000000);
|
||||
const __m512d arr1 = _mm512_shuffle_pd(ari1, ari1, 0b00000000);
|
||||
const __m512d res0 = _mm512_fmaddsub_pd(arr0, bri0, pro0);
|
||||
const __m512d res1 = _mm512_fmaddsub_pd(arr1, bri1, pro1);
|
||||
_mm512_storeu_pd(rr[0], res0);
|
||||
_mm512_storeu_pd(rr[1], res1);
|
||||
rr += 2; // ONCE
|
||||
aa += 2; // ONCE
|
||||
bb += 2; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
#if 0
|
||||
EXPORT void cplx_fftvec_mul_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b
|
||||
const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a
|
||||
const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a
|
||||
const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a
|
||||
const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a
|
||||
const __m256d pro0 = _mm256_mul_pd(aii0,bir0);
|
||||
const __m256d pro1 = _mm256_mul_pd(aii1,bir1);
|
||||
const __m256d pro2 = _mm256_mul_pd(aii2,bir2);
|
||||
const __m256d pro3 = _mm256_mul_pd(aii3,bir3);
|
||||
const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a
|
||||
const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a
|
||||
const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a
|
||||
const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a
|
||||
const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0);
|
||||
const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1);
|
||||
const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2);
|
||||
const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d res0 = _mm256_add_pd(ari0,bri0);
|
||||
const __m256d res1 = _mm256_add_pd(ari1,bri1);
|
||||
const __m256d res2 = _mm256_add_pd(ari2,bri2);
|
||||
const __m256d res3 = _mm256_add_pd(ari3,bri3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d sum0 = _mm256_add_pd(ari0,bri0);
|
||||
const __m256d sum1 = _mm256_add_pd(ari1,bri1);
|
||||
const __m256d sum2 = _mm256_add_pd(ari2,bri2);
|
||||
const __m256d sum3 = _mm256_add_pd(ari3,bri3);
|
||||
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||
const __m256d res0 = _mm256_sub_pd(rri0,sum0);
|
||||
const __m256d res1 = _mm256_sub_pd(rri1,sum1);
|
||||
const __m256d res2 = _mm256_sub_pd(rri2,sum2);
|
||||
const __m256d res3 = _mm256_sub_pd(rri3,sum3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(rr[%],ari%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(rr[0],ari0);
|
||||
_mm256_storeu_pd(rr[1],ari1);
|
||||
_mm256_storeu_pd(rr[2],ari2);
|
||||
_mm256_storeu_pd(rr[3],ari3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) {
|
||||
double(*aa)[4] = (double(*)[4])a;
|
||||
double(*bb)[4] = (double(*)[4])b;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
const __m256d om = _mm256_loadu_pd(omg);
|
||||
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5);
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5);
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5);
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5);
|
||||
__m256d p0 = _mm256_mul_pd(bir0,omii);
|
||||
__m256d p1 = _mm256_mul_pd(bir1,omii);
|
||||
__m256d p2 = _mm256_mul_pd(bir2,omii);
|
||||
__m256d p3 = _mm256_mul_pd(bir3,omii);
|
||||
p0 = _mm256_fmaddsub_pd(bri0,omrr,p0);
|
||||
p1 = _mm256_fmaddsub_pd(bri1,omrr,p1);
|
||||
p2 = _mm256_fmaddsub_pd(bri2,omrr,p2);
|
||||
p3 = _mm256_fmaddsub_pd(bri3,omrr,p3);
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0));
|
||||
_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1));
|
||||
_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2));
|
||||
_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3));
|
||||
_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0));
|
||||
_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1));
|
||||
_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2));
|
||||
_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3));
|
||||
// END_INTERLEAVE
|
||||
bb += 4;
|
||||
aa += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
#endif
|
||||
|
||||
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) {
|
||||
const uint32_t m = precomp->m;
|
||||
D8MEM* aa = (D8MEM*)a;
|
||||
D8MEM* bb = (D8MEM*)b;
|
||||
D8MEM* const aend = aa + (m >> 2);
|
||||
const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg));
|
||||
const __m512d omrr = _mm512_shuffle_pd(om, om, 0b00000000);
|
||||
const __m512d omii = _mm512_shuffle_pd(om, om, 0b11111111);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m512d bri% = _mm512_loadu_pd(bb[%]);
|
||||
const __m512d bir% = _mm512_shuffle_pd(bri%,bri%,0b10011001);
|
||||
__m512d p% = _mm512_mul_pd(bir%,omii);
|
||||
p% = _mm512_fmaddsub_pd(bri%,omrr,p%);
|
||||
const __m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||
_mm512_storeu_pd(aa[%],_mm512_add_pd(ari%,p%));
|
||||
_mm512_storeu_pd(bb[%],_mm512_sub_pd(ari%,p%));
|
||||
bb += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m512d bri0 = _mm512_loadu_pd(bb[0]);
|
||||
const __m512d bri1 = _mm512_loadu_pd(bb[1]);
|
||||
const __m512d bri2 = _mm512_loadu_pd(bb[2]);
|
||||
const __m512d bri3 = _mm512_loadu_pd(bb[3]);
|
||||
const __m512d bir0 = _mm512_shuffle_pd(bri0, bri0, 0b10011001);
|
||||
const __m512d bir1 = _mm512_shuffle_pd(bri1, bri1, 0b10011001);
|
||||
const __m512d bir2 = _mm512_shuffle_pd(bri2, bri2, 0b10011001);
|
||||
const __m512d bir3 = _mm512_shuffle_pd(bri3, bri3, 0b10011001);
|
||||
__m512d p0 = _mm512_mul_pd(bir0, omii);
|
||||
__m512d p1 = _mm512_mul_pd(bir1, omii);
|
||||
__m512d p2 = _mm512_mul_pd(bir2, omii);
|
||||
__m512d p3 = _mm512_mul_pd(bir3, omii);
|
||||
p0 = _mm512_fmaddsub_pd(bri0, omrr, p0);
|
||||
p1 = _mm512_fmaddsub_pd(bri1, omrr, p1);
|
||||
p2 = _mm512_fmaddsub_pd(bri2, omrr, p2);
|
||||
p3 = _mm512_fmaddsub_pd(bri3, omrr, p3);
|
||||
const __m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||
const __m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||
const __m512d ari2 = _mm512_loadu_pd(aa[2]);
|
||||
const __m512d ari3 = _mm512_loadu_pd(aa[3]);
|
||||
_mm512_storeu_pd(aa[0], _mm512_add_pd(ari0, p0));
|
||||
_mm512_storeu_pd(aa[1], _mm512_add_pd(ari1, p1));
|
||||
_mm512_storeu_pd(aa[2], _mm512_add_pd(ari2, p2));
|
||||
_mm512_storeu_pd(aa[3], _mm512_add_pd(ari3, p3));
|
||||
_mm512_storeu_pd(bb[0], _mm512_sub_pd(ari0, p0));
|
||||
_mm512_storeu_pd(bb[1], _mm512_sub_pd(ari1, p1));
|
||||
_mm512_storeu_pd(bb[2], _mm512_sub_pd(ari2, p2));
|
||||
_mm512_storeu_pd(bb[3], _mm512_sub_pd(ari3, p3));
|
||||
bb += 4; // ONCE
|
||||
aa += 4; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea,
|
||||
const void* omg) {
|
||||
const uint32_t m = precomp->m;
|
||||
const uint64_t OFFSET = slicea / sizeof(D8MEM);
|
||||
D8MEM* aa = (D8MEM*)a;
|
||||
const D8MEM* aend = aa + (m >> 2);
|
||||
const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg));
|
||||
const __m512d om1rr = _mm512_shuffle_pd(om, om, 0);
|
||||
const __m512d om1ii = _mm512_shuffle_pd(om, om, 15);
|
||||
const __m512d om2rr = _mm512_shuffle_pd(om, om, 0);
|
||||
const __m512d om2ii = _mm512_shuffle_pd(om, om, 0);
|
||||
const __m512d om3rr = _mm512_shuffle_pd(om, om, 15);
|
||||
const __m512d om3ii = _mm512_shuffle_pd(om, om, 15);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
__m512d ari% = _mm512_loadu_pd(aa[%]);
|
||||
__m512d bri% = _mm512_loadu_pd((aa+OFFSET)[%]);
|
||||
__m512d cri% = _mm512_loadu_pd((aa+2*OFFSET)[%]);
|
||||
__m512d dri% = _mm512_loadu_pd((aa+3*OFFSET)[%]);
|
||||
__m512d pa% = _mm512_shuffle_pd(cri%,cri%,5);
|
||||
__m512d pb% = _mm512_shuffle_pd(dri%,dri%,5);
|
||||
pa% = _mm512_mul_pd(pa%,om1ii);
|
||||
pb% = _mm512_mul_pd(pb%,om1ii);
|
||||
pa% = _mm512_fmaddsub_pd(cri%,om1rr,pa%);
|
||||
pb% = _mm512_fmaddsub_pd(dri%,om1rr,pb%);
|
||||
cri% = _mm512_sub_pd(ari%,pa%);
|
||||
dri% = _mm512_sub_pd(bri%,pb%);
|
||||
ari% = _mm512_add_pd(ari%,pa%);
|
||||
bri% = _mm512_add_pd(bri%,pb%);
|
||||
pa% = _mm512_shuffle_pd(bri%,bri%,5);
|
||||
pb% = _mm512_shuffle_pd(dri%,dri%,5);
|
||||
pa% = _mm512_mul_pd(pa%,om2ii);
|
||||
pb% = _mm512_mul_pd(pb%,om3ii);
|
||||
pa% = _mm512_fmaddsub_pd(bri%,om2rr,pa%);
|
||||
pb% = _mm512_fmaddsub_pd(dri%,om3rr,pb%);
|
||||
bri% = _mm512_sub_pd(ari%,pa%);
|
||||
dri% = _mm512_sub_pd(cri%,pb%);
|
||||
ari% = _mm512_add_pd(ari%,pa%);
|
||||
cri% = _mm512_add_pd(cri%,pb%);
|
||||
_mm512_storeu_pd(aa[%], ari%);
|
||||
_mm512_storeu_pd((aa+OFFSET)[%],bri%);
|
||||
_mm512_storeu_pd((aa+2*OFFSET)[%],cri%);
|
||||
_mm512_storeu_pd((aa+3*OFFSET)[%],dri%);
|
||||
aa += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 2
|
||||
__m512d ari0 = _mm512_loadu_pd(aa[0]);
|
||||
__m512d ari1 = _mm512_loadu_pd(aa[1]);
|
||||
__m512d bri0 = _mm512_loadu_pd((aa + OFFSET)[0]);
|
||||
__m512d bri1 = _mm512_loadu_pd((aa + OFFSET)[1]);
|
||||
__m512d cri0 = _mm512_loadu_pd((aa + 2 * OFFSET)[0]);
|
||||
__m512d cri1 = _mm512_loadu_pd((aa + 2 * OFFSET)[1]);
|
||||
__m512d dri0 = _mm512_loadu_pd((aa + 3 * OFFSET)[0]);
|
||||
__m512d dri1 = _mm512_loadu_pd((aa + 3 * OFFSET)[1]);
|
||||
__m512d pa0 = _mm512_shuffle_pd(cri0, cri0, 5);
|
||||
__m512d pa1 = _mm512_shuffle_pd(cri1, cri1, 5);
|
||||
__m512d pb0 = _mm512_shuffle_pd(dri0, dri0, 5);
|
||||
__m512d pb1 = _mm512_shuffle_pd(dri1, dri1, 5);
|
||||
pa0 = _mm512_mul_pd(pa0, om1ii);
|
||||
pa1 = _mm512_mul_pd(pa1, om1ii);
|
||||
pb0 = _mm512_mul_pd(pb0, om1ii);
|
||||
pb1 = _mm512_mul_pd(pb1, om1ii);
|
||||
pa0 = _mm512_fmaddsub_pd(cri0, om1rr, pa0);
|
||||
pa1 = _mm512_fmaddsub_pd(cri1, om1rr, pa1);
|
||||
pb0 = _mm512_fmaddsub_pd(dri0, om1rr, pb0);
|
||||
pb1 = _mm512_fmaddsub_pd(dri1, om1rr, pb1);
|
||||
cri0 = _mm512_sub_pd(ari0, pa0);
|
||||
cri1 = _mm512_sub_pd(ari1, pa1);
|
||||
dri0 = _mm512_sub_pd(bri0, pb0);
|
||||
dri1 = _mm512_sub_pd(bri1, pb1);
|
||||
ari0 = _mm512_add_pd(ari0, pa0);
|
||||
ari1 = _mm512_add_pd(ari1, pa1);
|
||||
bri0 = _mm512_add_pd(bri0, pb0);
|
||||
bri1 = _mm512_add_pd(bri1, pb1);
|
||||
pa0 = _mm512_shuffle_pd(bri0, bri0, 5);
|
||||
pa1 = _mm512_shuffle_pd(bri1, bri1, 5);
|
||||
pb0 = _mm512_shuffle_pd(dri0, dri0, 5);
|
||||
pb1 = _mm512_shuffle_pd(dri1, dri1, 5);
|
||||
pa0 = _mm512_mul_pd(pa0, om2ii);
|
||||
pa1 = _mm512_mul_pd(pa1, om2ii);
|
||||
pb0 = _mm512_mul_pd(pb0, om3ii);
|
||||
pb1 = _mm512_mul_pd(pb1, om3ii);
|
||||
pa0 = _mm512_fmaddsub_pd(bri0, om2rr, pa0);
|
||||
pa1 = _mm512_fmaddsub_pd(bri1, om2rr, pa1);
|
||||
pb0 = _mm512_fmaddsub_pd(dri0, om3rr, pb0);
|
||||
pb1 = _mm512_fmaddsub_pd(dri1, om3rr, pb1);
|
||||
bri0 = _mm512_sub_pd(ari0, pa0);
|
||||
bri1 = _mm512_sub_pd(ari1, pa1);
|
||||
dri0 = _mm512_sub_pd(cri0, pb0);
|
||||
dri1 = _mm512_sub_pd(cri1, pb1);
|
||||
ari0 = _mm512_add_pd(ari0, pa0);
|
||||
ari1 = _mm512_add_pd(ari1, pa1);
|
||||
cri0 = _mm512_add_pd(cri0, pb0);
|
||||
cri1 = _mm512_add_pd(cri1, pb1);
|
||||
_mm512_storeu_pd(aa[0], ari0);
|
||||
_mm512_storeu_pd(aa[1], ari1);
|
||||
_mm512_storeu_pd((aa + OFFSET)[0], bri0);
|
||||
_mm512_storeu_pd((aa + OFFSET)[1], bri1);
|
||||
_mm512_storeu_pd((aa + 2 * OFFSET)[0], cri0);
|
||||
_mm512_storeu_pd((aa + 2 * OFFSET)[1], cri1);
|
||||
_mm512_storeu_pd((aa + 3 * OFFSET)[0], dri0);
|
||||
_mm512_storeu_pd((aa + 3 * OFFSET)[1], dri1);
|
||||
aa += 2; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
#ifndef SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||
#define SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||
|
||||
#include "cplx_fft.h"
|
||||
|
||||
/** @brief a complex number contains two doubles real,imag */
|
||||
typedef double CPLX[2];
|
||||
|
||||
EXPORT void cplx_set(CPLX r, const CPLX a);
|
||||
EXPORT void cplx_neg(CPLX r, const CPLX a);
|
||||
EXPORT void cplx_add(CPLX r, const CPLX a, const CPLX b);
|
||||
EXPORT void cplx_sub(CPLX r, const CPLX a, const CPLX b);
|
||||
EXPORT void cplx_mul(CPLX r, const CPLX a, const CPLX b);
|
||||
|
||||
/**
|
||||
* @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
|
||||
* Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||
* Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||
* @param powom y represented as (yre,yim)
|
||||
*/
|
||||
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom);
|
||||
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]);
|
||||
|
||||
/**
|
||||
* Input: Q(y),Q(-y)
|
||||
* Output: P_0(z),P_1(z)
|
||||
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||
*/
|
||||
EXPORT void split_fft_last_ref(CPLX* data, const CPLX powom);
|
||||
|
||||
EXPORT void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
|
||||
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega);
|
||||
EXPORT void cplx_ifft16_ref(void* data, const void* omega);
|
||||
|
||||
/**
|
||||
* @brief compute the ifft evaluations of P in place
|
||||
* ifft(data) = ifft_rec(data, i);
|
||||
* function ifft_rec(data, omega) {
|
||||
* if #data = 1: return data
|
||||
* let s = sqrt(omega) w. re(s)>0
|
||||
* let (u,v) = data
|
||||
* return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s)
|
||||
* }
|
||||
* @param itables precomputed tables (contains all the powers of omega in the order they are used)
|
||||
* @param data vector of m complexes (coeffs as input, evals as output)
|
||||
*/
|
||||
EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* itables, void* data);
|
||||
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data);
|
||||
EXPORT void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data);
|
||||
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega);
|
||||
EXPORT void cplx_fft16_ref(void* data, const void* omega);
|
||||
|
||||
/**
|
||||
* @brief compute the fft evaluations of P in place
|
||||
* fft(data) = fft_rec(data, i);
|
||||
* function fft_rec(data, omega) {
|
||||
* if #data = 1: return data
|
||||
* let s = sqrt(omega) w. re(s)>0
|
||||
* let (u,v) = merge_fft(data, s)
|
||||
* return [fft_rec(u, s), fft_rec(v, -s)]
|
||||
* }
|
||||
* @param tables precomputed tables (contains all the powers of omega in the order they are used)
|
||||
* @param data vector of m complexes (coeffs as input, evals as output)
|
||||
*/
|
||||
EXPORT void cplx_fft_ref(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||
|
||||
/**
|
||||
* @brief merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial
|
||||
* Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||
* Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||
* @param powom y represented as (yre,yim)
|
||||
*/
|
||||
EXPORT void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom);
|
||||
|
||||
EXPORT void citwiddle(CPLX a, CPLX b, const CPLX om);
|
||||
EXPORT void ctwiddle(CPLX a, CPLX b, const CPLX om);
|
||||
EXPORT void invctwiddle(CPLX a, CPLX b, const CPLX ombar);
|
||||
EXPORT void invcitwiddle(CPLX a, CPLX b, const CPLX ombar);
|
||||
|
||||
// CONVERSIONS
|
||||
|
||||
/** @brief r = x from ZnX (coeffs as signed int32_t's ) to double */
|
||||
EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||
/** @brief r = x to ZnX (coeffs as signed int32_t's ) to double */
|
||||
EXPORT void cplx_to_znx32_ref(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||
EXPORT void cplx_to_znx32_avx2_fma(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||
/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) to double */
|
||||
EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x);
|
||||
/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) */
|
||||
EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
|
||||
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c);
|
||||
/** @brief r = x from RnX (coeffs as doubles ) to double */
|
||||
EXPORT void cplx_from_rnx64_ref(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||
EXPORT void cplx_from_rnx64_avx2_fma(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||
/** @brief r = x to RnX (coeffs as doubles ) to double */
|
||||
EXPORT void cplx_to_rnx64_ref(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
EXPORT void cplx_to_rnx64_avx2_fma(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
/** @brief r = x to integers in RnX (coeffs as doubles ) to double */
|
||||
EXPORT void cplx_round_to_rnx64_ref(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
EXPORT void cplx_round_to_rnx64_avx2_fma(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
|
||||
// fftvec operations
|
||||
/** @brief element-wise addmul r += ab */
|
||||
EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b);
|
||||
/** @brief element-wise mul r = ab */
|
||||
EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
|
||||
#endif // SPQLIOS_CPLX_FFT_INTERNAL_H
|
||||
@@ -1,109 +0,0 @@
|
||||
#ifndef SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||
#define SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||
|
||||
#include "cplx_fft.h"
|
||||
|
||||
typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP;
|
||||
typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP;
|
||||
|
||||
typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*);
|
||||
typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*);
|
||||
// conversions
|
||||
typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*);
|
||||
typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*);
|
||||
typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*);
|
||||
typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*);
|
||||
typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||
typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
// fftvec operations
|
||||
typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*);
|
||||
typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*);
|
||||
|
||||
typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*);
|
||||
typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*);
|
||||
|
||||
struct cplx_ifft_precomp {
|
||||
IFFT_FUNCTION function;
|
||||
int64_t m;
|
||||
uint64_t buf_size;
|
||||
double* powomegas;
|
||||
void* aligned_buffers;
|
||||
};
|
||||
|
||||
struct cplx_fft_precomp {
|
||||
FFT_FUNCTION function;
|
||||
int64_t m;
|
||||
uint64_t buf_size;
|
||||
double* powomegas;
|
||||
void* aligned_buffers;
|
||||
};
|
||||
|
||||
struct cplx_from_znx32_precomp {
|
||||
FROM_ZNX32_FUNCTION function;
|
||||
int64_t m;
|
||||
};
|
||||
|
||||
struct cplx_to_znx32_precomp {
|
||||
TO_ZNX32_FUNCTION function;
|
||||
int64_t m;
|
||||
double divisor;
|
||||
};
|
||||
|
||||
struct cplx_from_tnx32_precomp {
|
||||
FROM_TNX32_FUNCTION function;
|
||||
int64_t m;
|
||||
};
|
||||
|
||||
struct cplx_to_tnx32_precomp {
|
||||
TO_TNX32_FUNCTION function;
|
||||
int64_t m;
|
||||
double divisor;
|
||||
};
|
||||
|
||||
struct cplx_from_rnx64_precomp {
|
||||
FROM_RNX64_FUNCTION function;
|
||||
int64_t m;
|
||||
};
|
||||
|
||||
struct cplx_to_rnx64_precomp {
|
||||
TO_RNX64_FUNCTION function;
|
||||
int64_t m;
|
||||
double divisor;
|
||||
};
|
||||
|
||||
struct cplx_round_to_rnx64_precomp {
|
||||
ROUND_TO_RNX64_FUNCTION function;
|
||||
int64_t m;
|
||||
double divisor;
|
||||
uint32_t log2bound;
|
||||
};
|
||||
|
||||
typedef struct cplx_mul_precomp {
|
||||
FFTVEC_MUL_FUNCTION function;
|
||||
int64_t m;
|
||||
} CPLX_FFTVEC_MUL_PRECOMP;
|
||||
|
||||
typedef struct cplx_addmul_precomp {
|
||||
FFTVEC_ADDMUL_FUNCTION function;
|
||||
int64_t m;
|
||||
} CPLX_FFTVEC_ADDMUL_PRECOMP;
|
||||
|
||||
struct cplx_twiddle_precomp {
|
||||
FFTVEC_TWIDDLE_FUNCTION function;
|
||||
int64_t m;
|
||||
};
|
||||
|
||||
struct cplx_bitwiddle_precomp {
|
||||
FFTVEC_BITWIDDLE_FUNCTION function;
|
||||
int64_t m;
|
||||
};
|
||||
|
||||
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
|
||||
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
|
||||
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om);
|
||||
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om);
|
||||
|
||||
#endif // SPQLIOS_CPLX_FFT_PRIVATE_H
|
||||
@@ -1,367 +0,0 @@
|
||||
#include <memory.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
/** @brief (a,b) <- (a+omega.b,a-omega.b) */
|
||||
void ctwiddle(CPLX a, CPLX b, const CPLX om) {
|
||||
double re = om[0] * b[0] - om[1] * b[1];
|
||||
double im = om[0] * b[1] + om[1] * b[0];
|
||||
b[0] = a[0] - re;
|
||||
b[1] = a[1] - im;
|
||||
a[0] += re;
|
||||
a[1] += im;
|
||||
}
|
||||
|
||||
/** @brief (a,b) <- (a+i.omega.b,a-i.omega.b) */
|
||||
void citwiddle(CPLX a, CPLX b, const CPLX om) {
|
||||
double re = -om[1] * b[0] - om[0] * b[1];
|
||||
double im = -om[1] * b[1] + om[0] * b[0];
|
||||
b[0] = a[0] - re;
|
||||
b[1] = a[1] - im;
|
||||
a[0] += re;
|
||||
a[1] += im;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief FFT modulo X^16-omega^2 (in registers)
|
||||
* @param data contains 16 complexes
|
||||
* @param omega 8 complexes in this order:
|
||||
* omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
* alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
* j = sqrt(i), k=sqrt(j)
|
||||
*/
|
||||
void cplx_fft16_ref(void* data, const void* omega) {
|
||||
CPLX* d = data;
|
||||
const CPLX* om = omega;
|
||||
// first pass
|
||||
for (uint64_t i = 0; i < 8; ++i) {
|
||||
ctwiddle(d[0 + i], d[8 + i], om[0]);
|
||||
}
|
||||
//
|
||||
ctwiddle(d[0], d[4], om[1]);
|
||||
ctwiddle(d[1], d[5], om[1]);
|
||||
ctwiddle(d[2], d[6], om[1]);
|
||||
ctwiddle(d[3], d[7], om[1]);
|
||||
citwiddle(d[8], d[12], om[1]);
|
||||
citwiddle(d[9], d[13], om[1]);
|
||||
citwiddle(d[10], d[14], om[1]);
|
||||
citwiddle(d[11], d[15], om[1]);
|
||||
//
|
||||
ctwiddle(d[0], d[2], om[2]);
|
||||
ctwiddle(d[1], d[3], om[2]);
|
||||
citwiddle(d[4], d[6], om[2]);
|
||||
citwiddle(d[5], d[7], om[2]);
|
||||
ctwiddle(d[8], d[10], om[3]);
|
||||
ctwiddle(d[9], d[11], om[3]);
|
||||
citwiddle(d[12], d[14], om[3]);
|
||||
citwiddle(d[13], d[15], om[3]);
|
||||
//
|
||||
ctwiddle(d[0], d[1], om[4]);
|
||||
citwiddle(d[2], d[3], om[4]);
|
||||
ctwiddle(d[4], d[5], om[5]);
|
||||
citwiddle(d[6], d[7], om[5]);
|
||||
ctwiddle(d[8], d[9], om[6]);
|
||||
citwiddle(d[10], d[11], om[6]);
|
||||
ctwiddle(d[12], d[13], om[7]);
|
||||
citwiddle(d[14], d[15], om[7]);
|
||||
}
|
||||
|
||||
double cos_2pix(double x) { return m_accurate_cos(2 * M_PI * x); }
|
||||
double sin_2pix(double x) { return m_accurate_sin(2 * M_PI * x); }
|
||||
void cplx_set_e2pix(CPLX res, double x) {
|
||||
res[0] = cos_2pix(x);
|
||||
res[1] = sin_2pix(x);
|
||||
}
|
||||
|
||||
void cplx_fft16_precomp(const double entry_pwr, CPLX** omg) {
|
||||
static const double j_pow = 1. / 8.;
|
||||
static const double k_pow = 1. / 16.;
|
||||
const double pom = entry_pwr / 2.;
|
||||
const double pom_2 = entry_pwr / 4.;
|
||||
const double pom_4 = entry_pwr / 8.;
|
||||
const double pom_8 = entry_pwr / 16.;
|
||||
cplx_set_e2pix((*omg)[0], pom);
|
||||
cplx_set_e2pix((*omg)[1], pom_2);
|
||||
cplx_set_e2pix((*omg)[2], pom_4);
|
||||
cplx_set_e2pix((*omg)[3], pom_4 + j_pow);
|
||||
cplx_set_e2pix((*omg)[4], pom_8);
|
||||
cplx_set_e2pix((*omg)[5], pom_8 + j_pow);
|
||||
cplx_set_e2pix((*omg)[6], pom_8 + k_pow);
|
||||
cplx_set_e2pix((*omg)[7], pom_8 + j_pow + k_pow);
|
||||
*omg += 8;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief h twiddles-fft on the same omega
|
||||
* (also called merge-fft)merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial
|
||||
* Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||
* Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||
* @param powom y represented as (yre,yim)
|
||||
*/
|
||||
void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d1 = data + h;
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
ctwiddle(d0[i], d1[i], powom);
|
||||
}
|
||||
}
|
||||
|
||||
void cplx_bitwiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d1 = data + h;
|
||||
CPLX* d2 = data + 2 * h;
|
||||
CPLX* d3 = data + 3 * h;
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
ctwiddle(d0[i], d2[i], powom[0]);
|
||||
ctwiddle(d1[i], d3[i], powom[0]);
|
||||
}
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
ctwiddle(d0[i], d1[i], powom[1]);
|
||||
citwiddle(d2[i], d3[i], powom[1]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Input: P_0(z),P_1(z)
|
||||
* Output: Q(y),Q(-y)
|
||||
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||
*/
|
||||
void merge_fft_last_ref(CPLX* data, const CPLX powom) {
|
||||
CPLX prod;
|
||||
cplx_mul(prod, data[1], powom);
|
||||
cplx_sub(data[1], data[0], prod);
|
||||
cplx_add(data[0], data[0], prod);
|
||||
}
|
||||
|
||||
void cplx_fft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
CPLX* data = (CPLX*)dat;
|
||||
CPLX* const dend = data + m;
|
||||
for (int32_t h = m / 2; h >= 2; h >>= 1) {
|
||||
for (CPLX* d = data; d < dend; d += 2 * h) {
|
||||
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||
cplx_twiddle_fft_ref(h, d, **omg);
|
||||
*omg += 2;
|
||||
}
|
||||
#if 0
|
||||
printf("after merge %d: ", h);
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
for (CPLX* d = data; d < dend; d += 2) {
|
||||
// TODO see if encoding changes
|
||||
if ((*omg)[0][0] != -(*omg)[1][0]) abort();
|
||||
if ((*omg)[0][1] != -(*omg)[1][1]) abort();
|
||||
merge_fft_last_ref(d, **omg);
|
||||
*omg += 2;
|
||||
}
|
||||
#if 0
|
||||
printf("after last: ");
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
void cplx_fft_ref_bfs_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
CPLX* data = (CPLX*)dat;
|
||||
CPLX* const dend = data + m;
|
||||
uint32_t mm = m;
|
||||
uint32_t log2m = log2(m);
|
||||
if (log2m % 2 == 1) {
|
||||
cplx_twiddle_fft_ref(mm / 2, data, **omg);
|
||||
*omg += 2;
|
||||
mm >>= 1;
|
||||
}
|
||||
while (mm > 16) {
|
||||
uint32_t h = mm / 4;
|
||||
for (CPLX* d = data; d < dend; d += mm) {
|
||||
cplx_bitwiddle_fft_ref(h, d, *omg);
|
||||
*omg += 2;
|
||||
}
|
||||
mm = h;
|
||||
}
|
||||
for (CPLX* d = data; d < dend; d += 16) {
|
||||
cplx_fft16_ref(d, *omg);
|
||||
*omg += 8;
|
||||
}
|
||||
#if 0
|
||||
printf("after last: ");
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */
|
||||
void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data) {
|
||||
if (m == 1) return;
|
||||
const double pom = entry_pwr / 2.;
|
||||
const uint32_t h = m / 2;
|
||||
// apply the twiddle factors
|
||||
CPLX cpom;
|
||||
cplx_set_e2pix(cpom, pom);
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
ctwiddle(data[i], data[i + h], cpom);
|
||||
}
|
||||
// do the recursive calls
|
||||
cplx_fft_naive(h, pom, data);
|
||||
cplx_fft_naive(h, pom + 0.5, data + h);
|
||||
}
|
||||
|
||||
/** @brief fills omega for cplx_fft_bfs_16 modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||
void fill_cplx_fft_omegas_bfs_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
uint32_t mm = m;
|
||||
uint32_t log2m = log2(m);
|
||||
double ss = entry_pwr;
|
||||
if (log2m % 2 == 1) {
|
||||
uint32_t h = mm / 2;
|
||||
double pom = ss / 2.;
|
||||
for (uint32_t i = 0; i < m / mm; i++) {
|
||||
cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||
cplx_set(omg[0][1], omg[0][0]);
|
||||
*omg += 2;
|
||||
}
|
||||
mm = h;
|
||||
ss = pom;
|
||||
}
|
||||
while (mm > 16) {
|
||||
double pom = ss / 4.;
|
||||
uint32_t h = mm / 4;
|
||||
for (uint32_t i = 0; i < m / mm; i++) {
|
||||
double om = pom + fracrevbits(i) / 4.;
|
||||
cplx_set_e2pix(omg[0][0], 2. * om);
|
||||
cplx_set_e2pix(omg[0][1], om);
|
||||
*omg += 2;
|
||||
}
|
||||
mm = h;
|
||||
ss = pom;
|
||||
}
|
||||
{
|
||||
// mm=16
|
||||
for (uint32_t i = 0; i < m / 16; i++) {
|
||||
cplx_fft16_precomp(ss + fracrevbits(i), omg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief fills omega for cplx_fft_bfs_2 modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||
void fill_cplx_fft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
double pom = entry_pwr / 2.;
|
||||
for (int32_t h = m / 2; h >= 2; h >>= 1) {
|
||||
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||
cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||
cplx_set(omg[0][1], omg[0][0]);
|
||||
*omg += 2;
|
||||
}
|
||||
pom /= 2;
|
||||
}
|
||||
{
|
||||
// h=1
|
||||
for (uint32_t i = 0; i < m / 2; i++) {
|
||||
cplx_set_e2pix((*omg)[0], pom + fracrevbits(i) / 2.);
|
||||
cplx_neg((*omg)[1], (*omg)[0]);
|
||||
*omg += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief fills omega for cplx_fft_rec modulo X^m-exp(i.2.pi.entry_pwr) */
|
||||
void fill_cplx_fft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
// note that the cases below are for recursive calls only!
|
||||
// externally, this function shall only be called with m>=4096
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return fill_cplx_fft_omegas_bfs_2(entry_pwr, omg, m);
|
||||
if (m <= 2048) return fill_cplx_fft_omegas_bfs_16(entry_pwr, omg, m);
|
||||
double pom = entry_pwr / 2.;
|
||||
cplx_set_e2pix((*omg)[0], pom);
|
||||
cplx_set_e2pix((*omg)[1], pom);
|
||||
*omg += 2;
|
||||
fill_cplx_fft_omegas_rec_16(pom, omg, m / 2);
|
||||
fill_cplx_fft_omegas_rec_16(pom + 0.5, omg, m / 2);
|
||||
}
|
||||
|
||||
void cplx_fft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return cplx_fft_ref_bfs_2(dat, omg, m);
|
||||
if (m <= 2048) return cplx_fft_ref_bfs_16(dat, omg, m);
|
||||
const uint32_t h = m / 2;
|
||||
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||
cplx_twiddle_fft_ref(h, dat, **omg);
|
||||
*omg += 2;
|
||||
cplx_fft_ref_rec_16(dat, omg, h);
|
||||
cplx_fft_ref_rec_16(dat + h, omg, h);
|
||||
}
|
||||
|
||||
void cplx_fft_ref(const CPLX_FFT_PRECOMP* precomp, void* d) {
|
||||
CPLX* data = (CPLX*)d;
|
||||
const int32_t m = precomp->m;
|
||||
const CPLX* omg = (CPLX*)precomp->powomegas;
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return cplx_fft_ref_bfs_2(data, &omg, m);
|
||||
if (m <= 2048) return cplx_fft_ref_bfs_16(data, &omg, m);
|
||||
cplx_fft_ref_rec_16(data, &omg, m);
|
||||
}
|
||||
|
||||
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers) {
|
||||
const uint64_t OMG_SPACE = ceilto64b((2 * m) * sizeof(CPLX));
|
||||
const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX));
|
||||
void* reps = malloc(sizeof(CPLX_FFT_PRECOMP) + 63 // padding
|
||||
+ OMG_SPACE // tables //TODO 16?
|
||||
+ num_buffers * BUF_SIZE // buffers
|
||||
);
|
||||
uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(CPLX_FFT_PRECOMP));
|
||||
CPLX_FFT_PRECOMP* r = (CPLX_FFT_PRECOMP*)reps;
|
||||
r->m = m;
|
||||
r->buf_size = BUF_SIZE;
|
||||
r->powomegas = (double*)aligned_addr;
|
||||
r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE);
|
||||
// fill in powomegas
|
||||
CPLX* omg = (CPLX*)r->powomegas;
|
||||
if (m <= 8) {
|
||||
fill_cplx_fft_omegas_bfs_2(0.25, &omg, m);
|
||||
} else if (m <= 2048) {
|
||||
fill_cplx_fft_omegas_bfs_16(0.25, &omg, m);
|
||||
} else {
|
||||
fill_cplx_fft_omegas_rec_16(0.25, &omg, m);
|
||||
}
|
||||
if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort();
|
||||
// dispatch the right implementation
|
||||
{
|
||||
if (m <= 4) {
|
||||
// currently, we do not have any acceletated
|
||||
// implementation for m<=4
|
||||
r->function = cplx_fft_ref;
|
||||
} else if (CPU_SUPPORTS("fma")) {
|
||||
r->function = cplx_fft_avx2_fma;
|
||||
} else {
|
||||
r->function = cplx_fft_ref;
|
||||
}
|
||||
}
|
||||
return reps;
|
||||
}
|
||||
|
||||
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index) {
|
||||
return (uint8_t*)tables->aligned_buffers + buffer_index * tables->buf_size;
|
||||
}
|
||||
|
||||
EXPORT void cplx_fft_simple(uint32_t m, void* data) {
|
||||
static CPLX_FFT_PRECOMP* p[31] = {0};
|
||||
CPLX_FFT_PRECOMP** f = p + log2m(m);
|
||||
if (!*f) *f = new_cplx_fft_precomp(m, 0);
|
||||
(*f)->function(*f, data);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
|
||||
@@ -1,309 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef double D2MEM[2];
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const D2MEM* aa = (D2MEM*)a;
|
||||
const D2MEM* bb = (D2MEM*)b;
|
||||
D2MEM* rr = (D2MEM*)r;
|
||||
const D2MEM* const aend = aa + m;
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m128d ari% = _mm_loadu_pd(aa[%]);
|
||||
const __m128d bri% = _mm_loadu_pd(bb[%]);
|
||||
const __m128d rri% = _mm_loadu_pd(rr[%]);
|
||||
const __m128d bir% = _mm_shuffle_pd(bri%,bri%, 5);
|
||||
const __m128d aii% = _mm_shuffle_pd(ari%,ari%, 15);
|
||||
const __m128d pro% = _mm_fmaddsub_pd(aii%,bir%,rri%);
|
||||
const __m128d arr% = _mm_shuffle_pd(ari%,ari%, 0);
|
||||
const __m128d res% = _mm_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm_storeu_pd(rr[%],res%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
bb += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 2
|
||||
const __m128d ari0 = _mm_loadu_pd(aa[0]);
|
||||
const __m128d ari1 = _mm_loadu_pd(aa[1]);
|
||||
const __m128d bri0 = _mm_loadu_pd(bb[0]);
|
||||
const __m128d bri1 = _mm_loadu_pd(bb[1]);
|
||||
const __m128d rri0 = _mm_loadu_pd(rr[0]);
|
||||
const __m128d rri1 = _mm_loadu_pd(rr[1]);
|
||||
const __m128d bir0 = _mm_shuffle_pd(bri0, bri0, 0b01);
|
||||
const __m128d bir1 = _mm_shuffle_pd(bri1, bri1, 0b01);
|
||||
const __m128d aii0 = _mm_shuffle_pd(ari0, ari0, 0b11);
|
||||
const __m128d aii1 = _mm_shuffle_pd(ari1, ari1, 0b11);
|
||||
const __m128d pro0 = _mm_fmaddsub_pd(aii0, bir0, rri0);
|
||||
const __m128d pro1 = _mm_fmaddsub_pd(aii1, bir1, rri1);
|
||||
const __m128d arr0 = _mm_shuffle_pd(ari0, ari0, 0b00);
|
||||
const __m128d arr1 = _mm_shuffle_pd(ari1, ari1, 0b00);
|
||||
const __m128d res0 = _mm_fmaddsub_pd(arr0, bri0, pro0);
|
||||
const __m128d res1 = _mm_fmaddsub_pd(arr1, bri1, pro1);
|
||||
_mm_storeu_pd(rr[0], res0);
|
||||
_mm_storeu_pd(rr[1], res1);
|
||||
rr += 2; // ONCE
|
||||
aa += 2; // ONCE
|
||||
bb += 2; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
#if 0
|
||||
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b
|
||||
const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a
|
||||
const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a
|
||||
const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a
|
||||
const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a
|
||||
const __m256d pro0 = _mm256_mul_pd(aii0,bir0);
|
||||
const __m256d pro1 = _mm256_mul_pd(aii1,bir1);
|
||||
const __m256d pro2 = _mm256_mul_pd(aii2,bir2);
|
||||
const __m256d pro3 = _mm256_mul_pd(aii3,bir3);
|
||||
const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a
|
||||
const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a
|
||||
const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a
|
||||
const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a
|
||||
const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0);
|
||||
const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1);
|
||||
const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2);
|
||||
const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d res0 = _mm256_add_pd(ari0,bri0);
|
||||
const __m256d res1 = _mm256_add_pd(ari1,bri1);
|
||||
const __m256d res2 = _mm256_add_pd(ari2,bri2);
|
||||
const __m256d res3 = _mm256_add_pd(ari3,bri3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d sum0 = _mm256_add_pd(ari0,bri0);
|
||||
const __m256d sum1 = _mm256_add_pd(ari1,bri1);
|
||||
const __m256d sum2 = _mm256_add_pd(ari2,bri2);
|
||||
const __m256d sum3 = _mm256_add_pd(ari3,bri3);
|
||||
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||
const __m256d res0 = _mm256_sub_pd(rri0,sum0);
|
||||
const __m256d res1 = _mm256_sub_pd(rri1,sum1);
|
||||
const __m256d res2 = _mm256_sub_pd(rri2,sum2);
|
||||
const __m256d res3 = _mm256_sub_pd(rri3,sum3);
|
||||
_mm256_storeu_pd(rr[0],res0);
|
||||
_mm256_storeu_pd(rr[1],res1);
|
||||
_mm256_storeu_pd(rr[2],res2);
|
||||
_mm256_storeu_pd(rr[3],res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(rr[%],ari%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(rr[0],ari0);
|
||||
_mm256_storeu_pd(rr[1],ari1);
|
||||
_mm256_storeu_pd(rr[2],ari2);
|
||||
_mm256_storeu_pd(rr[3],ari3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) {
|
||||
double(*aa)[4] = (double(*)[4])a;
|
||||
double(*bb)[4] = (double(*)[4])b;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
const __m256d om = _mm256_loadu_pd(omg);
|
||||
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5);
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5);
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5);
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5);
|
||||
__m256d p0 = _mm256_mul_pd(bir0,omii);
|
||||
__m256d p1 = _mm256_mul_pd(bir1,omii);
|
||||
__m256d p2 = _mm256_mul_pd(bir2,omii);
|
||||
__m256d p3 = _mm256_mul_pd(bir3,omii);
|
||||
p0 = _mm256_fmaddsub_pd(bri0,omrr,p0);
|
||||
p1 = _mm256_fmaddsub_pd(bri1,omrr,p1);
|
||||
p2 = _mm256_fmaddsub_pd(bri2,omrr,p2);
|
||||
p3 = _mm256_fmaddsub_pd(bri3,omrr,p3);
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0));
|
||||
_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1));
|
||||
_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2));
|
||||
_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3));
|
||||
_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0));
|
||||
_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1));
|
||||
_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2));
|
||||
_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3));
|
||||
// END_INTERLEAVE
|
||||
bb += 4;
|
||||
aa += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_innerprod_avx2_fma(const CPLX_FFTVEC_INNERPROD_PRECOMP* precomp, const int32_t ellbar,
|
||||
const uint64_t lda, const uint64_t ldb,
|
||||
void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const uint32_t blk = precomp->blk;
|
||||
const uint32_t nblocks = precomp->nblocks;
|
||||
const CPLX* aa = (CPLX*)a;
|
||||
const CPLX* bb = (CPLX*)b;
|
||||
CPLX* rr = (CPLX*)r;
|
||||
const uint64_t ldda = lda >> 4; // in CPLX
|
||||
const uint64_t lddb = ldb >> 4;
|
||||
if (m==0) {
|
||||
memset(r, 0, m*sizeof(CPLX));
|
||||
return;
|
||||
}
|
||||
for (uint32_t k=0; k<nblocks; ++k) {
|
||||
const uint64_t offset = k*blk;
|
||||
const CPLX* aaa = aa+offset;
|
||||
const CPLX* bbb = bb+offset;
|
||||
CPLX *rrr = rr+offset;
|
||||
cplx_fftvec_mul_fma(&precomp->mul_func, rrr, aaa, bbb);
|
||||
for (int32_t i=1; i<ellbar; ++i) {
|
||||
cplx_fftvec_addmul_fma(&precomp->addmul_func, rrr, aaa + i * ldda, bbb + i * lddb);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -1,387 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef double D4MEM[4];
|
||||
|
||||
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b
|
||||
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a
|
||||
const __m256d pro% = _mm256_mul_pd(aii%,bir%);
|
||||
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a
|
||||
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
bb += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); // conj of b
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); // conj of b
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); // conj of b
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); // conj of b
|
||||
const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); // im of a
|
||||
const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); // im of a
|
||||
const __m256d aii2 = _mm256_shuffle_pd(ari2, ari2, 15); // im of a
|
||||
const __m256d aii3 = _mm256_shuffle_pd(ari3, ari3, 15); // im of a
|
||||
const __m256d pro0 = _mm256_mul_pd(aii0, bir0);
|
||||
const __m256d pro1 = _mm256_mul_pd(aii1, bir1);
|
||||
const __m256d pro2 = _mm256_mul_pd(aii2, bir2);
|
||||
const __m256d pro3 = _mm256_mul_pd(aii3, bir3);
|
||||
const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); // rr of a
|
||||
const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); // rr of a
|
||||
const __m256d arr2 = _mm256_shuffle_pd(ari2, ari2, 0); // rr of a
|
||||
const __m256d arr3 = _mm256_shuffle_pd(ari3, ari3, 0); // rr of a
|
||||
const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0);
|
||||
const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1);
|
||||
const __m256d res2 = _mm256_fmaddsub_pd(arr2, bri2, pro2);
|
||||
const __m256d res3 = _mm256_fmaddsub_pd(arr3, bri3, pro3);
|
||||
_mm256_storeu_pd(rr[0], res0);
|
||||
_mm256_storeu_pd(rr[1], res1);
|
||||
_mm256_storeu_pd(rr[2], res2);
|
||||
_mm256_storeu_pd(rr[3], res3);
|
||||
rr += 4; // ONCE
|
||||
aa += 4; // ONCE
|
||||
bb += 4; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5);
|
||||
const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15);
|
||||
const __m256d pro% = _mm256_fmaddsub_pd(aii%,bir%,rri%);
|
||||
const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0);
|
||||
const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
bb += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 2
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5);
|
||||
const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15);
|
||||
const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15);
|
||||
const __m256d pro0 = _mm256_fmaddsub_pd(aii0, bir0, rri0);
|
||||
const __m256d pro1 = _mm256_fmaddsub_pd(aii1, bir1, rri1);
|
||||
const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0);
|
||||
const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0);
|
||||
const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0);
|
||||
const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1);
|
||||
_mm256_storeu_pd(rr[0], res0);
|
||||
_mm256_storeu_pd(rr[1], res1);
|
||||
rr += 2; // ONCE
|
||||
aa += 2; // ONCE
|
||||
bb += 2; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea,
|
||||
const void* omg) {
|
||||
const uint32_t m = precomp->m;
|
||||
const uint64_t OFFSET = slicea / sizeof(D4MEM);
|
||||
D4MEM* aa = (D4MEM*)a;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
const __m256d om = _mm256_loadu_pd(omg);
|
||||
const __m256d om1rr = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d om1ii = _mm256_shuffle_pd(om, om, 15);
|
||||
const __m256d om2rr = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d om2ii = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d om3rr = _mm256_shuffle_pd(om, om, 15);
|
||||
const __m256d om3ii = _mm256_shuffle_pd(om, om, 15);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
__m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
__m256d bri% = _mm256_loadu_pd((aa+OFFSET)[%]);
|
||||
__m256d cri% = _mm256_loadu_pd((aa+2*OFFSET)[%]);
|
||||
__m256d dri% = _mm256_loadu_pd((aa+3*OFFSET)[%]);
|
||||
__m256d pa% = _mm256_shuffle_pd(cri%,cri%,5);
|
||||
__m256d pb% = _mm256_shuffle_pd(dri%,dri%,5);
|
||||
pa% = _mm256_mul_pd(pa%,om1ii);
|
||||
pb% = _mm256_mul_pd(pb%,om1ii);
|
||||
pa% = _mm256_fmaddsub_pd(cri%,om1rr,pa%);
|
||||
pb% = _mm256_fmaddsub_pd(dri%,om1rr,pb%);
|
||||
cri% = _mm256_sub_pd(ari%,pa%);
|
||||
dri% = _mm256_sub_pd(bri%,pb%);
|
||||
ari% = _mm256_add_pd(ari%,pa%);
|
||||
bri% = _mm256_add_pd(bri%,pb%);
|
||||
pa% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||
pb% = _mm256_shuffle_pd(dri%,dri%,5);
|
||||
pa% = _mm256_mul_pd(pa%,om2ii);
|
||||
pb% = _mm256_mul_pd(pb%,om3ii);
|
||||
pa% = _mm256_fmaddsub_pd(bri%,om2rr,pa%);
|
||||
pb% = _mm256_fmaddsub_pd(dri%,om3rr,pb%);
|
||||
bri% = _mm256_sub_pd(ari%,pa%);
|
||||
dri% = _mm256_sub_pd(cri%,pb%);
|
||||
ari% = _mm256_add_pd(ari%,pa%);
|
||||
cri% = _mm256_add_pd(cri%,pb%);
|
||||
_mm256_storeu_pd(aa[%], ari%);
|
||||
_mm256_storeu_pd((aa+OFFSET)[%],bri%);
|
||||
_mm256_storeu_pd((aa+2*OFFSET)[%],cri%);
|
||||
_mm256_storeu_pd((aa+3*OFFSET)[%],dri%);
|
||||
aa += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 1
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
__m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
__m256d bri0 = _mm256_loadu_pd((aa + OFFSET)[0]);
|
||||
__m256d cri0 = _mm256_loadu_pd((aa + 2 * OFFSET)[0]);
|
||||
__m256d dri0 = _mm256_loadu_pd((aa + 3 * OFFSET)[0]);
|
||||
__m256d pa0 = _mm256_shuffle_pd(cri0, cri0, 5);
|
||||
__m256d pb0 = _mm256_shuffle_pd(dri0, dri0, 5);
|
||||
pa0 = _mm256_mul_pd(pa0, om1ii);
|
||||
pb0 = _mm256_mul_pd(pb0, om1ii);
|
||||
pa0 = _mm256_fmaddsub_pd(cri0, om1rr, pa0);
|
||||
pb0 = _mm256_fmaddsub_pd(dri0, om1rr, pb0);
|
||||
cri0 = _mm256_sub_pd(ari0, pa0);
|
||||
dri0 = _mm256_sub_pd(bri0, pb0);
|
||||
ari0 = _mm256_add_pd(ari0, pa0);
|
||||
bri0 = _mm256_add_pd(bri0, pb0);
|
||||
pa0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||
pb0 = _mm256_shuffle_pd(dri0, dri0, 5);
|
||||
pa0 = _mm256_mul_pd(pa0, om2ii);
|
||||
pb0 = _mm256_mul_pd(pb0, om3ii);
|
||||
pa0 = _mm256_fmaddsub_pd(bri0, om2rr, pa0);
|
||||
pb0 = _mm256_fmaddsub_pd(dri0, om3rr, pb0);
|
||||
bri0 = _mm256_sub_pd(ari0, pa0);
|
||||
dri0 = _mm256_sub_pd(cri0, pb0);
|
||||
ari0 = _mm256_add_pd(ari0, pa0);
|
||||
cri0 = _mm256_add_pd(cri0, pb0);
|
||||
_mm256_storeu_pd(aa[0], ari0);
|
||||
_mm256_storeu_pd((aa + OFFSET)[0], bri0);
|
||||
_mm256_storeu_pd((aa + 2 * OFFSET)[0], cri0);
|
||||
_mm256_storeu_pd((aa + 3 * OFFSET)[0], dri0);
|
||||
aa += 1; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d sum% = _mm256_add_pd(ari%,bri%);
|
||||
const __m256d rri% = _mm256_loadu_pd(rr[%]);
|
||||
const __m256d res% = _mm256_sub_pd(rri%,sum%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d sum0 = _mm256_add_pd(ari0, bri0);
|
||||
const __m256d sum1 = _mm256_add_pd(ari1, bri1);
|
||||
const __m256d sum2 = _mm256_add_pd(ari2, bri2);
|
||||
const __m256d sum3 = _mm256_add_pd(ari3, bri3);
|
||||
const __m256d rri0 = _mm256_loadu_pd(rr[0]);
|
||||
const __m256d rri1 = _mm256_loadu_pd(rr[1]);
|
||||
const __m256d rri2 = _mm256_loadu_pd(rr[2]);
|
||||
const __m256d rri3 = _mm256_loadu_pd(rr[3]);
|
||||
const __m256d res0 = _mm256_sub_pd(rri0, sum0);
|
||||
const __m256d res1 = _mm256_sub_pd(rri1, sum1);
|
||||
const __m256d res2 = _mm256_sub_pd(rri2, sum2);
|
||||
const __m256d res3 = _mm256_sub_pd(rri3, sum3);
|
||||
_mm256_storeu_pd(rr[0], res0);
|
||||
_mm256_storeu_pd(rr[1], res1);
|
||||
_mm256_storeu_pd(rr[2], res2);
|
||||
_mm256_storeu_pd(rr[3], res3);
|
||||
// END_INTERLEAVE
|
||||
rr += 4;
|
||||
aa += 4;
|
||||
bb += 4;
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
const double(*bb)[4] = (double(*)[4])b;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d res% = _mm256_add_pd(ari%,bri%);
|
||||
_mm256_storeu_pd(rr[%],res%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
bb += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d res0 = _mm256_add_pd(ari0, bri0);
|
||||
const __m256d res1 = _mm256_add_pd(ari1, bri1);
|
||||
const __m256d res2 = _mm256_add_pd(ari2, bri2);
|
||||
const __m256d res3 = _mm256_add_pd(ari3, bri3);
|
||||
_mm256_storeu_pd(rr[0], res0);
|
||||
_mm256_storeu_pd(rr[1], res1);
|
||||
_mm256_storeu_pd(rr[2], res2);
|
||||
_mm256_storeu_pd(rr[3], res3);
|
||||
rr += 4; // ONCE
|
||||
aa += 4; // ONCE
|
||||
bb += 4; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) {
|
||||
const uint32_t m = precomp->m;
|
||||
double(*aa)[4] = (double(*)[4])a;
|
||||
double(*bb)[4] = (double(*)[4])b;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
const __m256d om = _mm256_loadu_pd(omg);
|
||||
const __m256d omrr = _mm256_shuffle_pd(om, om, 0);
|
||||
const __m256d omii = _mm256_shuffle_pd(om, om, 15);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d bri% = _mm256_loadu_pd(bb[%]);
|
||||
const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5);
|
||||
__m256d p% = _mm256_mul_pd(bir%,omii);
|
||||
p% = _mm256_fmaddsub_pd(bri%,omrr,p%);
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%));
|
||||
_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%));
|
||||
bb += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d bri0 = _mm256_loadu_pd(bb[0]);
|
||||
const __m256d bri1 = _mm256_loadu_pd(bb[1]);
|
||||
const __m256d bri2 = _mm256_loadu_pd(bb[2]);
|
||||
const __m256d bri3 = _mm256_loadu_pd(bb[3]);
|
||||
const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5);
|
||||
const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5);
|
||||
const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5);
|
||||
const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5);
|
||||
__m256d p0 = _mm256_mul_pd(bir0, omii);
|
||||
__m256d p1 = _mm256_mul_pd(bir1, omii);
|
||||
__m256d p2 = _mm256_mul_pd(bir2, omii);
|
||||
__m256d p3 = _mm256_mul_pd(bir3, omii);
|
||||
p0 = _mm256_fmaddsub_pd(bri0, omrr, p0);
|
||||
p1 = _mm256_fmaddsub_pd(bri1, omrr, p1);
|
||||
p2 = _mm256_fmaddsub_pd(bri2, omrr, p2);
|
||||
p3 = _mm256_fmaddsub_pd(bri3, omrr, p3);
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(aa[0], _mm256_add_pd(ari0, p0));
|
||||
_mm256_storeu_pd(aa[1], _mm256_add_pd(ari1, p1));
|
||||
_mm256_storeu_pd(aa[2], _mm256_add_pd(ari2, p2));
|
||||
_mm256_storeu_pd(aa[3], _mm256_add_pd(ari3, p3));
|
||||
_mm256_storeu_pd(bb[0], _mm256_sub_pd(ari0, p0));
|
||||
_mm256_storeu_pd(bb[1], _mm256_sub_pd(ari1, p1));
|
||||
_mm256_storeu_pd(bb[2], _mm256_sub_pd(ari2, p2));
|
||||
_mm256_storeu_pd(bb[3], _mm256_sub_pd(ari3, p3));
|
||||
bb += 4; // ONCE
|
||||
aa += 4; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||
const double(*aa)[4] = (double(*)[4])a;
|
||||
double(*rr)[4] = (double(*)[4])r;
|
||||
const double(*const aend)[4] = aa + (m >> 1);
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ari% = _mm256_loadu_pd(aa[%]);
|
||||
_mm256_storeu_pd(rr[%],ari%);
|
||||
rr += @; // ONCE
|
||||
aa += @; // ONCE
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 4
|
||||
// This block is automatically generated from the template above
|
||||
// by the interleave.pl script. Please do not edit by hand
|
||||
const __m256d ari0 = _mm256_loadu_pd(aa[0]);
|
||||
const __m256d ari1 = _mm256_loadu_pd(aa[1]);
|
||||
const __m256d ari2 = _mm256_loadu_pd(aa[2]);
|
||||
const __m256d ari3 = _mm256_loadu_pd(aa[3]);
|
||||
_mm256_storeu_pd(rr[0], ari0);
|
||||
_mm256_storeu_pd(rr[1], ari1);
|
||||
_mm256_storeu_pd(rr[2], ari2);
|
||||
_mm256_storeu_pd(rr[3], ari3);
|
||||
rr += 4; // ONCE
|
||||
aa += 4; // ONCE
|
||||
// END_INTERLEAVE
|
||||
} while (aa < aend);
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const CPLX* aa = (CPLX*)a;
|
||||
const CPLX* bb = (CPLX*)b;
|
||||
CPLX* rr = (CPLX*)r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1];
|
||||
const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0];
|
||||
rr[i][0] += re;
|
||||
rr[i][1] += im;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
const uint32_t m = precomp->m;
|
||||
const CPLX* aa = (CPLX*)a;
|
||||
const CPLX* bb = (CPLX*)b;
|
||||
CPLX* rr = (CPLX*)r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1];
|
||||
const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0];
|
||||
rr[i][0] = re;
|
||||
rr[i][1] = im;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void* init_cplx_fftvec_addmul_precomp(CPLX_FFTVEC_ADDMUL_PRECOMP* r, uint32_t m) {
|
||||
if (m & (m - 1)) return spqlios_error("m must be a power of two");
|
||||
r->m = m;
|
||||
if (m <= 4) {
|
||||
r->function = cplx_fftvec_addmul_ref;
|
||||
} else if (CPU_SUPPORTS("fma")) {
|
||||
r->function = cplx_fftvec_addmul_fma;
|
||||
} else {
|
||||
r->function = cplx_fftvec_addmul_ref;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
EXPORT void* init_cplx_fftvec_mul_precomp(CPLX_FFTVEC_MUL_PRECOMP* r, uint32_t m) {
|
||||
if (m & (m - 1)) return spqlios_error("m must be a power of two");
|
||||
r->m = m;
|
||||
if (m <= 4) {
|
||||
r->function = cplx_fftvec_mul_ref;
|
||||
} else if (CPU_SUPPORTS("fma")) {
|
||||
r->function = cplx_fftvec_mul_fma;
|
||||
} else {
|
||||
r->function = cplx_fftvec_mul_ref;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) {
|
||||
CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP));
|
||||
return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m));
|
||||
}
|
||||
|
||||
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m) {
|
||||
CPLX_FFTVEC_MUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP));
|
||||
return spqlios_keep_or_free(r, init_cplx_fftvec_mul_precomp(r, m));
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) {
|
||||
static CPLX_FFTVEC_MUL_PRECOMP p[31] = {0};
|
||||
CPLX_FFTVEC_MUL_PRECOMP* f = p + log2m(m);
|
||||
if (!f->function) {
|
||||
if (!init_cplx_fftvec_mul_precomp(f, m)) abort();
|
||||
}
|
||||
f->function(f, r, a, b);
|
||||
}
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) {
|
||||
static CPLX_FFTVEC_ADDMUL_PRECOMP p[31] = {0};
|
||||
CPLX_FFTVEC_ADDMUL_PRECOMP* f = p + log2m(m);
|
||||
if (!f->function) {
|
||||
if (!init_cplx_fftvec_addmul_precomp(f, m)) abort();
|
||||
}
|
||||
f->function(f, r, a, b);
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
.globl cplx_ifft16_avx_fma
|
||||
cplx_ifft16_avx_fma:
|
||||
vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 0(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 32(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma
|
||||
vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma
|
||||
vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta
|
||||
vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmsubadd231pd %ymm13, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
|
||||
.third_pass:
|
||||
vmovupd 64(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 80(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vsubpd %ymm9,%ymm8,%ymm4
|
||||
vsubpd %ymm11,%ymm10,%ymm5
|
||||
vsubpd %ymm13,%ymm12,%ymm6
|
||||
vsubpd %ymm15,%ymm14,%ymm7
|
||||
vaddpd %ymm9,%ymm8,%ymm8
|
||||
vaddpd %ymm11,%ymm10,%ymm10
|
||||
vaddpd %ymm13,%ymm12,%ymm12
|
||||
vaddpd %ymm15,%ymm14,%ymm14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm9
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm13
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm9,%ymm1,%ymm9
|
||||
vmulpd %ymm11,%ymm0,%ymm11
|
||||
vmulpd %ymm13,%ymm3,%ymm13
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm11
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm13
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
|
||||
.second_pass:
|
||||
vmovupd 96(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vsubpd %ymm10,%ymm8,%ymm4
|
||||
vsubpd %ymm11,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm12,%ymm6
|
||||
vsubpd %ymm15,%ymm13,%ymm7
|
||||
vaddpd %ymm10,%ymm8,%ymm8
|
||||
vaddpd %ymm11,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm12,%ymm12
|
||||
vaddpd %ymm15,%ymm13,%ymm13
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm10
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm10,%ymm1,%ymm10
|
||||
vmulpd %ymm11,%ymm1,%ymm11
|
||||
vmulpd %ymm14,%ymm0,%ymm14
|
||||
vmulpd %ymm15,%ymm0,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10
|
||||
vfmaddsub231pd %ymm5, %ymm0, %ymm11
|
||||
vfmsubadd231pd %ymm6, %ymm1, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm1, %ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd 112(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vsubpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vsubpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vaddpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vaddpd %ymm15,%ymm11,%ymm11
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm1,%ymm13
|
||||
vmulpd %ymm14,%ymm1,%ymm14
|
||||
vmulpd %ymm15,%ymm1,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmaddsub231pd %ymm5, %ymm0, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm0, %ymm14
|
||||
vfmaddsub231pd %ymm7, %ymm0, %ymm15
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
.size cplx_ifft16_avx_fma, .-cplx_ifft16_avx_fma
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
@@ -1,192 +0,0 @@
|
||||
.text
|
||||
.p2align 4
|
||||
.globl cplx_ifft16_avx_fma
|
||||
.def cplx_ifft16_avx_fma; .scl 2; .type 32; .endef
|
||||
cplx_ifft16_avx_fma:
|
||||
|
||||
pushq %rdi
|
||||
pushq %rsi
|
||||
movq %rcx,%rdi
|
||||
movq %rdx,%rsi
|
||||
subq $0x100,%rsp
|
||||
movdqu %xmm6,(%rsp)
|
||||
movdqu %xmm7,0x10(%rsp)
|
||||
movdqu %xmm8,0x20(%rsp)
|
||||
movdqu %xmm9,0x30(%rsp)
|
||||
movdqu %xmm10,0x40(%rsp)
|
||||
movdqu %xmm11,0x50(%rsp)
|
||||
movdqu %xmm12,0x60(%rsp)
|
||||
movdqu %xmm13,0x70(%rsp)
|
||||
movdqu %xmm14,0x80(%rsp)
|
||||
movdqu %xmm15,0x90(%rsp)
|
||||
callq cplx_ifft16_avx_fma_amd64
|
||||
movdqu (%rsp),%xmm6
|
||||
movdqu 0x10(%rsp),%xmm7
|
||||
movdqu 0x20(%rsp),%xmm8
|
||||
movdqu 0x30(%rsp),%xmm9
|
||||
movdqu 0x40(%rsp),%xmm10
|
||||
movdqu 0x50(%rsp),%xmm11
|
||||
movdqu 0x60(%rsp),%xmm12
|
||||
movdqu 0x70(%rsp),%xmm13
|
||||
movdqu 0x80(%rsp),%xmm14
|
||||
movdqu 0x90(%rsp),%xmm15
|
||||
addq $0x100,%rsp
|
||||
popq %rsi
|
||||
popq %rdi
|
||||
retq
|
||||
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
|
||||
cplx_ifft16_avx_fma_amd64:
|
||||
vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 0(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 32(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma
|
||||
vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma
|
||||
vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta
|
||||
vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmsubadd231pd %ymm13, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
|
||||
.third_pass:
|
||||
vmovupd 64(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 80(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vsubpd %ymm9,%ymm8,%ymm4
|
||||
vsubpd %ymm11,%ymm10,%ymm5
|
||||
vsubpd %ymm13,%ymm12,%ymm6
|
||||
vsubpd %ymm15,%ymm14,%ymm7
|
||||
vaddpd %ymm9,%ymm8,%ymm8
|
||||
vaddpd %ymm11,%ymm10,%ymm10
|
||||
vaddpd %ymm13,%ymm12,%ymm12
|
||||
vaddpd %ymm15,%ymm14,%ymm14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm9
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm13
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm9,%ymm1,%ymm9
|
||||
vmulpd %ymm11,%ymm0,%ymm11
|
||||
vmulpd %ymm13,%ymm3,%ymm13
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm11
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm13
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
|
||||
.second_pass:
|
||||
vmovupd 96(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vsubpd %ymm10,%ymm8,%ymm4
|
||||
vsubpd %ymm11,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm12,%ymm6
|
||||
vsubpd %ymm15,%ymm13,%ymm7
|
||||
vaddpd %ymm10,%ymm8,%ymm8
|
||||
vaddpd %ymm11,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm12,%ymm12
|
||||
vaddpd %ymm15,%ymm13,%ymm13
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm10
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm11
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm10,%ymm1,%ymm10
|
||||
vmulpd %ymm11,%ymm1,%ymm11
|
||||
vmulpd %ymm14,%ymm0,%ymm14
|
||||
vmulpd %ymm15,%ymm0,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10
|
||||
vfmaddsub231pd %ymm5, %ymm0, %ymm11
|
||||
vfmsubadd231pd %ymm6, %ymm1, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm1, %ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd 112(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vsubpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vsubpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vaddpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vaddpd %ymm15,%ymm11,%ymm11
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm1,%ymm13
|
||||
vmulpd %ymm14,%ymm1,%ymm14
|
||||
vmulpd %ymm15,%ymm1,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmaddsub231pd %ymm5, %ymm0, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm0, %ymm14
|
||||
vfmaddsub231pd %ymm7, %ymm0, %ymm15
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
@@ -1,267 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef double D4MEM[4];
|
||||
typedef double D2MEM[2];
|
||||
|
||||
/**
|
||||
* @brief complex ifft via bfs strategy (for m between 2 and 8)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_ifft_avx2_fma_bfs_2(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||
double* data = (double*)dat;
|
||||
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||
{
|
||||
// loop with h = 1
|
||||
// we do not do any particular optimization in this loop,
|
||||
// since this function is only called for small dimensions
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
/*
|
||||
BEGIN_TEMPLATE
|
||||
const __m256d ab% = _mm256_loadu_pd(dd[0+2*%]);
|
||||
const __m256d cd% = _mm256_loadu_pd(dd[1+2*%]);
|
||||
const __m256d ac% = _mm256_permute2f128_pd(ab%, cd%, 0b100000);
|
||||
const __m256d bd% = _mm256_permute2f128_pd(ab%, cd%, 0b110001);
|
||||
const __m256d sum% = _mm256_add_pd(ac%, bd%);
|
||||
const __m256d diff% = _mm256_sub_pd(ac%, bd%);
|
||||
const __m256d diffbar% = _mm256_shuffle_pd(diff%, diff%, 5);
|
||||
const __m256d om% = _mm256_load_pd((*omg)[0+%]);
|
||||
const __m256d omre% = _mm256_unpacklo_pd(om%, om%);
|
||||
const __m256d omim% = _mm256_unpackhi_pd(om%, om%);
|
||||
const __m256d t1% = _mm256_mul_pd(diffbar%, omim%);
|
||||
const __m256d t2% = _mm256_fmaddsub_pd(diff%, omre%, t1%);
|
||||
const __m256d newab% = _mm256_permute2f128_pd(sum%, t2%, 0b100000);
|
||||
const __m256d newcd% = _mm256_permute2f128_pd(sum%, t2%, 0b110001);
|
||||
_mm256_storeu_pd(dd[0+2*%], newab%);
|
||||
_mm256_storeu_pd(dd[1+2*%], newcd%);
|
||||
dd += 2*@;
|
||||
*omg += 2*@;
|
||||
END_TEMPLATE
|
||||
*/
|
||||
// BEGIN_INTERLEAVE 1
|
||||
const __m256d ab0 = _mm256_loadu_pd(dd[0 + 2 * 0]);
|
||||
const __m256d cd0 = _mm256_loadu_pd(dd[1 + 2 * 0]);
|
||||
const __m256d ac0 = _mm256_permute2f128_pd(ab0, cd0, 0b100000);
|
||||
const __m256d bd0 = _mm256_permute2f128_pd(ab0, cd0, 0b110001);
|
||||
const __m256d sum0 = _mm256_add_pd(ac0, bd0);
|
||||
const __m256d diff0 = _mm256_sub_pd(ac0, bd0);
|
||||
const __m256d diffbar0 = _mm256_shuffle_pd(diff0, diff0, 5);
|
||||
const __m256d om0 = _mm256_load_pd((*omga)[0 + 0]);
|
||||
const __m256d omre0 = _mm256_unpacklo_pd(om0, om0);
|
||||
const __m256d omim0 = _mm256_unpackhi_pd(om0, om0);
|
||||
const __m256d t10 = _mm256_mul_pd(diffbar0, omim0);
|
||||
const __m256d t20 = _mm256_fmaddsub_pd(diff0, omre0, t10);
|
||||
const __m256d newab0 = _mm256_permute2f128_pd(sum0, t20, 0b100000);
|
||||
const __m256d newcd0 = _mm256_permute2f128_pd(sum0, t20, 0b110001);
|
||||
_mm256_storeu_pd(dd[0 + 2 * 0], newab0);
|
||||
_mm256_storeu_pd(dd[1 + 2 * 0], newcd0);
|
||||
dd += 2 * 1;
|
||||
*omga += 2 * 1;
|
||||
// END_INTERLEAVE
|
||||
} while (dd < finaldd);
|
||||
#if 0
|
||||
printf("c after first: ");
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",ddata[ii][0],ddata[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
// general case
|
||||
const uint32_t ms2 = m >> 1;
|
||||
for (uint32_t _2nblock = 2; _2nblock <= ms2; _2nblock <<= 1) {
|
||||
// _2nblock = h in ref code
|
||||
uint32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
const __m256d om = _mm256_load_pd((*omga)[0]);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, b);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
const __m256d diff = _mm256_sub_pd(a, b);
|
||||
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||
_mm256_storeu_pd(ddmid[0], t2);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
dd += nblock;
|
||||
*omga += 2;
|
||||
} while (dd < finaldd);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex fft via bfs strategy (for m >= 16)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_ifft_avx2_fma_bfs_16(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||
double* data = (double*)dat;
|
||||
D4MEM* const finaldd = (D4MEM*)(data + 2 * m);
|
||||
// base iteration when h = _2nblock == 8
|
||||
{
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
cplx_ifft16_avx_fma(dd, *omga);
|
||||
dd += 8;
|
||||
*omga += 8;
|
||||
} while (dd < finaldd);
|
||||
}
|
||||
// general case
|
||||
const uint32_t log2m = _mm_popcnt_u32(m - 1); //_popcnt32(m-1); //log2(m);
|
||||
uint32_t h = 16;
|
||||
if (log2m % 2 == 1) {
|
||||
uint32_t nblock = h >> 1; // =h/2 in ref code
|
||||
D4MEM* dd = (D4MEM*)data;
|
||||
do {
|
||||
const __m128d om1 = _mm_loadu_pd((*omga)[0]);
|
||||
const __m256d om = _mm256_set_m128d(om1, om1);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, b);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
const __m256d diff = _mm256_sub_pd(a, b);
|
||||
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||
_mm256_storeu_pd(ddmid[0], t2);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
dd += nblock;
|
||||
*omga += 1;
|
||||
} while (dd < finaldd);
|
||||
h = 32;
|
||||
}
|
||||
for (; h < m; h <<= 2) {
|
||||
// _2nblock = h in ref code
|
||||
uint32_t nblock = h >> 1; // =h/2 in ref code
|
||||
D4MEM* dd0 = (D4MEM*)data;
|
||||
do {
|
||||
const __m128d om1 = _mm_loadu_pd((*omga)[0]);
|
||||
const __m128d al1 = _mm_loadu_pd((*omga)[1]);
|
||||
const __m256d om = _mm256_set_m128d(om1, om1);
|
||||
const __m256d al = _mm256_set_m128d(al1, al1);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
const __m256d omim = _mm256_unpackhi_pd(om, om);
|
||||
const __m256d alre = _mm256_unpacklo_pd(al, al);
|
||||
const __m256d alim = _mm256_unpackhi_pd(al, al);
|
||||
D4MEM* const ddend = (dd0 + nblock);
|
||||
D4MEM* dd1 = ddend;
|
||||
D4MEM* dd2 = dd1 + nblock;
|
||||
D4MEM* dd3 = dd2 + nblock;
|
||||
do {
|
||||
__m256d u0 = _mm256_loadu_pd(dd0[0]);
|
||||
__m256d u1 = _mm256_loadu_pd(dd1[0]);
|
||||
__m256d u2 = _mm256_loadu_pd(dd2[0]);
|
||||
__m256d u3 = _mm256_loadu_pd(dd3[0]);
|
||||
__m256d u4 = _mm256_add_pd(u0, u1);
|
||||
__m256d u5 = _mm256_sub_pd(u0, u1);
|
||||
__m256d u6 = _mm256_add_pd(u2, u3);
|
||||
__m256d u7 = _mm256_sub_pd(u2, u3);
|
||||
u0 = _mm256_shuffle_pd(u5, u5, 5);
|
||||
u2 = _mm256_shuffle_pd(u7, u7, 5);
|
||||
u1 = _mm256_mul_pd(u0, omim);
|
||||
u3 = _mm256_mul_pd(u2, omre);
|
||||
u5 = _mm256_fmaddsub_pd(u5, omre, u1);
|
||||
u7 = _mm256_fmsubadd_pd(u7, omim, u3);
|
||||
//////
|
||||
u0 = _mm256_add_pd(u4, u6);
|
||||
u1 = _mm256_add_pd(u5, u7);
|
||||
u2 = _mm256_sub_pd(u4, u6);
|
||||
u3 = _mm256_sub_pd(u5, u7);
|
||||
u4 = _mm256_shuffle_pd(u2, u2, 5);
|
||||
u5 = _mm256_shuffle_pd(u3, u3, 5);
|
||||
u6 = _mm256_mul_pd(u4, alim);
|
||||
u7 = _mm256_mul_pd(u5, alim);
|
||||
u2 = _mm256_fmaddsub_pd(u2, alre, u6);
|
||||
u3 = _mm256_fmaddsub_pd(u3, alre, u7);
|
||||
///////
|
||||
_mm256_storeu_pd(dd0[0], u0);
|
||||
_mm256_storeu_pd(dd1[0], u1);
|
||||
_mm256_storeu_pd(dd2[0], u2);
|
||||
_mm256_storeu_pd(dd3[0], u3);
|
||||
dd0 += 1;
|
||||
dd1 += 1;
|
||||
dd2 += 1;
|
||||
dd3 += 1;
|
||||
} while (dd0 < ddend);
|
||||
dd0 += 3 * nblock;
|
||||
*omga += 2;
|
||||
} while (dd0 < finaldd);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex ifft via dfs recursion (for m >= 16)
|
||||
* @param dat the data to run the algorithm on
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
void cplx_ifft_avx2_fma_rec_16(D4MEM* dat, const D2MEM** omga, uint32_t m) {
|
||||
if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(dat, omga, m);
|
||||
if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(dat, omga, m);
|
||||
const uint32_t _2nblock = m >> 1; // = h in ref code
|
||||
const uint32_t nblock = _2nblock >> 1; // =h/2 in ref code
|
||||
cplx_ifft_avx2_fma_rec_16(dat, omga, _2nblock);
|
||||
cplx_ifft_avx2_fma_rec_16(dat + nblock, omga, _2nblock);
|
||||
{
|
||||
// final iteration
|
||||
D4MEM* dd = dat;
|
||||
const __m256d om = _mm256_load_pd((*omga)[0]);
|
||||
const __m256d omre = _mm256_unpacklo_pd(om, om);
|
||||
const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om));
|
||||
D4MEM* const ddend = (dd + nblock);
|
||||
D4MEM* ddmid = ddend;
|
||||
do {
|
||||
const __m256d a = _mm256_loadu_pd(dd[0]);
|
||||
const __m256d b = _mm256_loadu_pd(ddmid[0]);
|
||||
const __m256d newa = _mm256_add_pd(a, b);
|
||||
_mm256_storeu_pd(dd[0], newa);
|
||||
const __m256d diff = _mm256_sub_pd(a, b);
|
||||
const __m256d t1 = _mm256_mul_pd(diff, omre);
|
||||
const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5);
|
||||
const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1);
|
||||
_mm256_storeu_pd(ddmid[0], t2);
|
||||
dd += 1;
|
||||
ddmid += 1;
|
||||
} while (dd < ddend);
|
||||
*omga += 2;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief complex ifft via best strategy (for m>=1)
|
||||
* @param dat the data to run the algorithm on: m complex numbers
|
||||
* @param omg precomputed tables (must have been filled with fill_omega)
|
||||
* @param m ring dimension of the FFT (modulo X^m-i)
|
||||
*/
|
||||
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* precomp, void* d) {
|
||||
const uint32_t m = precomp->m;
|
||||
const D2MEM* omg = (D2MEM*)precomp->powomegas;
|
||||
if (m <= 1) return;
|
||||
if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(d, &omg, m);
|
||||
if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(d, &omg, m);
|
||||
cplx_ifft_avx2_fma_rec_16(d, &omg, m);
|
||||
}
|
||||
@@ -1,312 +0,0 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft.h"
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
/** @brief (a,b) <- (a+b,omegabar.(a-b)) */
|
||||
void invctwiddle(CPLX a, CPLX b, const CPLX ombar) {
|
||||
double diffre = a[0] - b[0];
|
||||
double diffim = a[1] - b[1];
|
||||
a[0] = a[0] + b[0];
|
||||
a[1] = a[1] + b[1];
|
||||
b[0] = diffre * ombar[0] - diffim * ombar[1];
|
||||
b[1] = diffre * ombar[1] + diffim * ombar[0];
|
||||
}
|
||||
|
||||
/** @brief (a,b) <- (a+b,-i.omegabar(a-b)) */
|
||||
void invcitwiddle(CPLX a, CPLX b, const CPLX ombar) {
|
||||
double diffre = a[0] - b[0];
|
||||
double diffim = a[1] - b[1];
|
||||
a[0] = a[0] + b[0];
|
||||
a[1] = a[1] + b[1];
|
||||
//-i(x+iy)=-ix+y
|
||||
b[0] = diffre * ombar[1] + diffim * ombar[0];
|
||||
b[1] = -diffre * ombar[0] + diffim * ombar[1];
|
||||
}
|
||||
|
||||
/** @brief exp(-i.2pi.x) */
|
||||
void cplx_set_e2pimx(CPLX res, double x) {
|
||||
res[0] = m_accurate_cos(2 * M_PI * x);
|
||||
res[1] = -m_accurate_sin(2 * M_PI * x);
|
||||
}
|
||||
/**
|
||||
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||
double fracrevbits(uint32_t i);
|
||||
/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */
|
||||
void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data) {
|
||||
if (m == 1) return;
|
||||
const double pom = entry_pwr / 2.;
|
||||
const uint32_t h = m / 2;
|
||||
CPLX cpom;
|
||||
cplx_set_e2pimx(cpom, pom);
|
||||
// do the recursive calls
|
||||
cplx_ifft_naive(h, pom, data);
|
||||
cplx_ifft_naive(h, pom + 0.5, data + h);
|
||||
// apply the inverse twiddle factors
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
invctwiddle(data[i], data[i + h], cpom);
|
||||
}
|
||||
}
|
||||
|
||||
void cplx_ifft16_precomp(const double entry_pwr, CPLX** omg) {
|
||||
static const double j_pow = 1. / 8.;
|
||||
static const double k_pow = 1. / 16.;
|
||||
const double pom = entry_pwr / 2.;
|
||||
const double pom_2 = entry_pwr / 4.;
|
||||
const double pom_4 = entry_pwr / 8.;
|
||||
const double pom_8 = entry_pwr / 16.;
|
||||
cplx_set_e2pimx((*omg)[0], pom_8);
|
||||
cplx_set_e2pimx((*omg)[1], pom_8 + j_pow);
|
||||
cplx_set_e2pimx((*omg)[2], pom_8 + k_pow);
|
||||
cplx_set_e2pimx((*omg)[3], pom_8 + j_pow + k_pow);
|
||||
cplx_set_e2pimx((*omg)[4], pom_4);
|
||||
cplx_set_e2pimx((*omg)[5], pom_4 + j_pow);
|
||||
cplx_set_e2pimx((*omg)[6], pom_2);
|
||||
cplx_set_e2pimx((*omg)[7], pom);
|
||||
*omg += 8;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief iFFT modulo X^16-omega^2 (in registers)
|
||||
* @param data contains 16 complexes
|
||||
* @param omegabar 8 complexes in this order:
|
||||
* gammabar,jb.gammabar,kb.gammabar,kbjb.gammabar,
|
||||
* betabar,jb.betabar,alphabar,omegabar
|
||||
* alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
* jb = sqrt(ib), kb=sqrt(jb)
|
||||
*/
|
||||
void cplx_ifft16_ref(void* data, const void* omegabar) {
|
||||
CPLX* d = data;
|
||||
const CPLX* om = omegabar;
|
||||
// fourth pass inverse
|
||||
invctwiddle(d[0], d[1], om[0]);
|
||||
invcitwiddle(d[2], d[3], om[0]);
|
||||
invctwiddle(d[4], d[5], om[1]);
|
||||
invcitwiddle(d[6], d[7], om[1]);
|
||||
invctwiddle(d[8], d[9], om[2]);
|
||||
invcitwiddle(d[10], d[11], om[2]);
|
||||
invctwiddle(d[12], d[13], om[3]);
|
||||
invcitwiddle(d[14], d[15], om[3]);
|
||||
// third pass inverse
|
||||
invctwiddle(d[0], d[2], om[4]);
|
||||
invctwiddle(d[1], d[3], om[4]);
|
||||
invcitwiddle(d[4], d[6], om[4]);
|
||||
invcitwiddle(d[5], d[7], om[4]);
|
||||
invctwiddle(d[8], d[10], om[5]);
|
||||
invctwiddle(d[9], d[11], om[5]);
|
||||
invcitwiddle(d[12], d[14], om[5]);
|
||||
invcitwiddle(d[13], d[15], om[5]);
|
||||
// second pass inverse
|
||||
invctwiddle(d[0], d[4], om[6]);
|
||||
invctwiddle(d[1], d[5], om[6]);
|
||||
invctwiddle(d[2], d[6], om[6]);
|
||||
invctwiddle(d[3], d[7], om[6]);
|
||||
invcitwiddle(d[8], d[12], om[6]);
|
||||
invcitwiddle(d[9], d[13], om[6]);
|
||||
invcitwiddle(d[10], d[14], om[6]);
|
||||
invcitwiddle(d[11], d[15], om[6]);
|
||||
// first pass
|
||||
for (uint64_t i = 0; i < 8; ++i) {
|
||||
invctwiddle(d[0 + i], d[8 + i], om[7]);
|
||||
}
|
||||
}
|
||||
|
||||
void cplx_ifft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
CPLX* const dend = dat + m;
|
||||
for (CPLX* d = dat; d < dend; d += 2) {
|
||||
split_fft_last_ref(d, (*omg)[0]);
|
||||
*omg += 1;
|
||||
}
|
||||
#if 0
|
||||
printf("after first: ");
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
int32_t Ms2 = m / 2;
|
||||
for (int32_t h = 2; h <= Ms2; h <<= 1) {
|
||||
for (CPLX* d = dat; d < dend; d += 2 * h) {
|
||||
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||
cplx_split_fft_ref(h, d, **omg);
|
||||
*omg += 2;
|
||||
}
|
||||
#if 0
|
||||
printf("after split %d: ", h);
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void cplx_ifft_ref_bfs_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
const uint64_t log2m = log2(m);
|
||||
CPLX* const dend = dat + m;
|
||||
// h=1,2,4,8 use the 16-dim macroblock
|
||||
for (CPLX* d = dat; d < dend; d += 16) {
|
||||
cplx_ifft16_ref(d, *omg);
|
||||
*omg += 8;
|
||||
}
|
||||
#if 0
|
||||
printf("after first: ");
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
int32_t h = 16;
|
||||
if (log2m % 2 != 0) {
|
||||
// if parity needs it, uses one regular twiddle
|
||||
for (CPLX* d = dat; d < dend; d += 2 * h) {
|
||||
cplx_split_fft_ref(h, d, **omg);
|
||||
*omg += 1;
|
||||
}
|
||||
h = 32;
|
||||
}
|
||||
// h=16,...,2*floor(Ms2/2) use the bitwiddle
|
||||
for (; h < m; h <<= 2) {
|
||||
for (CPLX* d = dat; d < dend; d += 4 * h) {
|
||||
cplx_bisplit_fft_ref(h, d, *omg);
|
||||
*omg += 2;
|
||||
}
|
||||
#if 0
|
||||
printf("after split %d: ", h);
|
||||
for (uint64_t ii=0; ii<nn/2; ++ii) {
|
||||
printf("%.6lf %.6lf ",data[ii][0],data[ii][1]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void fill_cplx_ifft_omegas_bfs_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
const uint64_t log2m = log2(m);
|
||||
double pwr = entry_pwr * 16. / m;
|
||||
{
|
||||
// h=8
|
||||
for (uint32_t i = 0; i < m / 16; i++) {
|
||||
cplx_ifft16_precomp(pwr + fracrevbits(i), omg);
|
||||
}
|
||||
}
|
||||
int32_t h = 16;
|
||||
if (log2m % 2 != 0) {
|
||||
// if parity needs it, uses one regular twiddle
|
||||
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||
cplx_set_e2pimx(omg[0][0], pwr + fracrevbits(i) / 2.);
|
||||
*omg += 1;
|
||||
}
|
||||
pwr *= 2.;
|
||||
h = 32;
|
||||
}
|
||||
for (; h < m; h <<= 2) {
|
||||
for (uint32_t i = 0; i < m / (2 * h); i += 2) {
|
||||
cplx_set_e2pimx(omg[0][0], pwr + fracrevbits(i) / 2.);
|
||||
cplx_set_e2pimx(omg[0][1], 2. * pwr + fracrevbits(i));
|
||||
*omg += 2;
|
||||
}
|
||||
pwr *= 4.;
|
||||
}
|
||||
}
|
||||
|
||||
void fill_cplx_ifft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
double pom = entry_pwr / m;
|
||||
{
|
||||
// h=1
|
||||
for (uint32_t i = 0; i < m / 2; i++) {
|
||||
cplx_set_e2pimx((*omg)[0], pom + fracrevbits(i) / 2.);
|
||||
*omg += 1; // optim function reads by 1
|
||||
}
|
||||
}
|
||||
for (int32_t h = 2; h <= m / 2; h <<= 1) {
|
||||
pom *= 2;
|
||||
for (uint32_t i = 0; i < m / (2 * h); i++) {
|
||||
cplx_set_e2pimx(omg[0][0], pom + fracrevbits(i) / 2.);
|
||||
cplx_set(omg[0][1], omg[0][0]);
|
||||
*omg += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fill_cplx_ifft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) {
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return fill_cplx_ifft_omegas_bfs_2(entry_pwr, omg, m);
|
||||
if (m <= 2048) return fill_cplx_ifft_omegas_bfs_16(entry_pwr, omg, m);
|
||||
double pom = entry_pwr / 2.;
|
||||
uint32_t h = m / 2;
|
||||
fill_cplx_ifft_omegas_rec_16(pom, omg, h);
|
||||
fill_cplx_ifft_omegas_rec_16(pom + 0.5, omg, h);
|
||||
cplx_set_e2pimx((*omg)[0], pom);
|
||||
cplx_set((*omg)[1], (*omg)[0]);
|
||||
*omg += 2;
|
||||
}
|
||||
|
||||
void cplx_ifft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) {
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return cplx_ifft_ref_bfs_2(dat, omg, m);
|
||||
if (m <= 2048) return cplx_ifft_ref_bfs_16(dat, omg, m);
|
||||
const uint32_t h = m / 2;
|
||||
cplx_ifft_ref_rec_16(dat, omg, h);
|
||||
cplx_ifft_ref_rec_16(dat + h, omg, h);
|
||||
if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort();
|
||||
cplx_split_fft_ref(h, dat, **omg);
|
||||
*omg += 2;
|
||||
}
|
||||
|
||||
EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* precomp, void* d) {
|
||||
CPLX* data = (CPLX*)d;
|
||||
const int32_t m = precomp->m;
|
||||
const CPLX* omg = (CPLX*)precomp->powomegas;
|
||||
if (m == 1) return;
|
||||
if (m <= 8) return cplx_ifft_ref_bfs_2(data, &omg, m);
|
||||
if (m <= 2048) return cplx_ifft_ref_bfs_16(data, &omg, m);
|
||||
cplx_ifft_ref_rec_16(data, &omg, m);
|
||||
}
|
||||
|
||||
EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers) {
|
||||
const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(CPLX));
|
||||
const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX));
|
||||
void* reps = malloc(sizeof(CPLX_IFFT_PRECOMP) + 63 // padding
|
||||
+ OMG_SPACE // tables
|
||||
+ num_buffers * BUF_SIZE // buffers
|
||||
);
|
||||
uint64_t aligned_addr = ceilto64b((uint64_t)reps + sizeof(CPLX_IFFT_PRECOMP));
|
||||
CPLX_IFFT_PRECOMP* r = (CPLX_IFFT_PRECOMP*)reps;
|
||||
r->m = m;
|
||||
r->buf_size = BUF_SIZE;
|
||||
r->powomegas = (double*)aligned_addr;
|
||||
r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE);
|
||||
// fill in powomegas
|
||||
CPLX* omg = (CPLX*)r->powomegas;
|
||||
fill_cplx_ifft_omegas_rec_16(0.25, &omg, m);
|
||||
if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort();
|
||||
{
|
||||
if (m <= 4) {
|
||||
// currently, we do not have any acceletated
|
||||
// implementation for m<=4
|
||||
r->function = cplx_ifft_ref;
|
||||
} else if (CPU_SUPPORTS("fma")) {
|
||||
r->function = cplx_ifft_avx2_fma;
|
||||
} else {
|
||||
r->function = cplx_ifft_ref;
|
||||
}
|
||||
}
|
||||
return reps;
|
||||
}
|
||||
|
||||
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* itables, uint32_t buffer_index) {
|
||||
return (uint8_t*)itables->aligned_buffers + buffer_index * itables->buf_size;
|
||||
}
|
||||
|
||||
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { itables->function(itables, data); }
|
||||
|
||||
EXPORT void cplx_ifft_simple(uint32_t m, void* data) {
|
||||
static CPLX_IFFT_PRECOMP* p[31] = {0};
|
||||
CPLX_IFFT_PRECOMP** f = p + log2m(m);
|
||||
if (!*f) *f = new_cplx_ifft_precomp(m, 0);
|
||||
(*f)->function(*f, data);
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
/*
|
||||
* This file is extracted from the implementation of the FFT on Arm64/Neon
|
||||
* available in https://github.com/cothan/Falcon-Arm (neon/macrof.h).
|
||||
* =============================================================================
|
||||
* Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG)
|
||||
* ECE Department, George Mason University
|
||||
* Fairfax, VA, U.S.A.
|
||||
* @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* =============================================================================
|
||||
*
|
||||
* This 64-bit Floating point NEON macro x1 has not been modified and is provided as is.
|
||||
*/
|
||||
|
||||
#ifndef MACROF_H
|
||||
#define MACROF_H
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
// c <= addr x1
|
||||
#define vload(c, addr) c = vld1q_f64(addr);
|
||||
// c <= addr interleave 2
|
||||
#define vload2(c, addr) c = vld2q_f64(addr);
|
||||
// c <= addr interleave 4
|
||||
#define vload4(c, addr) c = vld4q_f64(addr);
|
||||
|
||||
#define vstore(addr, c) vst1q_f64(addr, c);
|
||||
// addr <= c
|
||||
#define vstore2(addr, c) vst2q_f64(addr, c);
|
||||
// addr <= c
|
||||
#define vstore4(addr, c) vst4q_f64(addr, c);
|
||||
|
||||
// c <= addr x2
|
||||
#define vloadx2(c, addr) c = vld1q_f64_x2(addr);
|
||||
// c <= addr x3
|
||||
#define vloadx3(c, addr) c = vld1q_f64_x3(addr);
|
||||
|
||||
// addr <= c
|
||||
#define vstorex2(addr, c) vst1q_f64_x2(addr, c);
|
||||
|
||||
// c = a - b
|
||||
#define vfsub(c, a, b) c = vsubq_f64(a, b);
|
||||
|
||||
// c = a + b
|
||||
#define vfadd(c, a, b) c = vaddq_f64(a, b);
|
||||
|
||||
// c = a * b
|
||||
#define vfmul(c, a, b) c = vmulq_f64(a, b);
|
||||
|
||||
// c = a * n (n is constant)
|
||||
#define vfmuln(c, a, n) c = vmulq_n_f64(a, n);
|
||||
|
||||
// Swap from a|b to b|a
|
||||
#define vswap(c, a) c = vextq_f64(a, a, 1);
|
||||
|
||||
// c = a * b[i]
|
||||
#define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i);
|
||||
|
||||
// c = 1/a
|
||||
#define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a);
|
||||
|
||||
// c = -a
|
||||
#define vfneg(c, a) c = vnegq_f64(a);
|
||||
|
||||
#define transpose_f64(a, b, t, ia, ib, it) \
|
||||
t.val[it] = a.val[ia]; \
|
||||
a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \
|
||||
b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]);
|
||||
|
||||
/*
|
||||
* c = a + jb
|
||||
* c[0] = a[0] - b[1]
|
||||
* c[1] = a[1] + b[0]
|
||||
*/
|
||||
#define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b);
|
||||
|
||||
/*
|
||||
* c = a - jb
|
||||
* c[0] = a[0] + b[1]
|
||||
* c[1] = a[1] - b[0]
|
||||
*/
|
||||
#define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b);
|
||||
|
||||
// c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0]
|
||||
#define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b);
|
||||
|
||||
// c[0] = c[0] - b[1]*a[1], c[1] = c[1] + b[0]*a[1]
|
||||
#define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b);
|
||||
|
||||
// c[0] = c[0] - b[0]*a[0], c[1] = c[1] - b[1]*a[0]
|
||||
#define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b);
|
||||
|
||||
// c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1]
|
||||
#define vfcmla_270(c, a, b) c = vcmlaq_rot270_f64(c, a, b);
|
||||
|
||||
/*
|
||||
* Complex MUL: c = a*b
|
||||
* c[0] = a[0]*b[0] - a[1]*b[1]
|
||||
* c[1] = a[0]*b[1] + a[1]*b[0]
|
||||
*/
|
||||
#define FPC_CMUL(c, a, b) \
|
||||
c = vmulq_laneq_f64(b, a, 0); \
|
||||
c = vcmlaq_rot90_f64(c, a, b);
|
||||
|
||||
/*
|
||||
* Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1])
|
||||
* c[0] = b[0]*a[0] + b[1]*a[1]
|
||||
* c[1] = + b[0]*a[1] - b[1]*a[0]
|
||||
*/
|
||||
#define FPC_CMUL_CONJ(c, a, b) \
|
||||
c = vmulq_laneq_f64(a, b, 0); \
|
||||
c = vcmlaq_rot270_f64(c, b, a);
|
||||
|
||||
#if FMA == 1
|
||||
// d = c + a *b
|
||||
#define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b);
|
||||
// d = c - a * b
|
||||
#define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b);
|
||||
// d = c + a * b[i]
|
||||
#define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i);
|
||||
// d = c - a * b[i]
|
||||
#define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i);
|
||||
|
||||
#else
|
||||
// d = c + a *b
|
||||
#define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b));
|
||||
// d = c - a *b
|
||||
#define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b));
|
||||
// d = c + a * b[i]
|
||||
#define vfmla_lane(d, c, a, b, i) d = vaddq_f64(c, vmulq_laneq_f64(a, b, i));
|
||||
|
||||
#define vfmls_lane(d, c, a, b, i) d = vsubq_f64(c, vmulq_laneq_f64(a, b, i));
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -1,420 +0,0 @@
|
||||
/*
|
||||
* This file is extracted from the implementation of the FFT on Arm64/Neon
|
||||
* available in https://github.com/cothan/Falcon-Arm (neon/macrof.h).
|
||||
* =============================================================================
|
||||
* Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG)
|
||||
* ECE Department, George Mason University
|
||||
* Fairfax, VA, U.S.A.
|
||||
* @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* =============================================================================
|
||||
*
|
||||
* This 64-bit Floating point NEON macro x4 has not been modified and is provided as is.
|
||||
*/
|
||||
|
||||
#ifndef MACROFX4_H
|
||||
#define MACROFX4_H
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "macrof.h"
|
||||
|
||||
#define vloadx4(c, addr) c = vld1q_f64_x4(addr);
|
||||
|
||||
#define vstorex4(addr, c) vst1q_f64_x4(addr, c);
|
||||
|
||||
#define vfdupx4(c, constant) \
|
||||
c.val[0] = vdupq_n_f64(constant); \
|
||||
c.val[1] = vdupq_n_f64(constant); \
|
||||
c.val[2] = vdupq_n_f64(constant); \
|
||||
c.val[3] = vdupq_n_f64(constant);
|
||||
|
||||
#define vfnegx4(c, a) \
|
||||
c.val[0] = vnegq_f64(a.val[0]); \
|
||||
c.val[1] = vnegq_f64(a.val[1]); \
|
||||
c.val[2] = vnegq_f64(a.val[2]); \
|
||||
c.val[3] = vnegq_f64(a.val[3]);
|
||||
|
||||
#define vfmulnx4(c, a, n) \
|
||||
c.val[0] = vmulq_n_f64(a.val[0], n); \
|
||||
c.val[1] = vmulq_n_f64(a.val[1], n); \
|
||||
c.val[2] = vmulq_n_f64(a.val[2], n); \
|
||||
c.val[3] = vmulq_n_f64(a.val[3], n);
|
||||
|
||||
// c = a - b
|
||||
#define vfsubx4(c, a, b) \
|
||||
c.val[0] = vsubq_f64(a.val[0], b.val[0]); \
|
||||
c.val[1] = vsubq_f64(a.val[1], b.val[1]); \
|
||||
c.val[2] = vsubq_f64(a.val[2], b.val[2]); \
|
||||
c.val[3] = vsubq_f64(a.val[3], b.val[3]);
|
||||
|
||||
// c = a + b
|
||||
#define vfaddx4(c, a, b) \
|
||||
c.val[0] = vaddq_f64(a.val[0], b.val[0]); \
|
||||
c.val[1] = vaddq_f64(a.val[1], b.val[1]); \
|
||||
c.val[2] = vaddq_f64(a.val[2], b.val[2]); \
|
||||
c.val[3] = vaddq_f64(a.val[3], b.val[3]);
|
||||
|
||||
#define vfmulx4(c, a, b) \
|
||||
c.val[0] = vmulq_f64(a.val[0], b.val[0]); \
|
||||
c.val[1] = vmulq_f64(a.val[1], b.val[1]); \
|
||||
c.val[2] = vmulq_f64(a.val[2], b.val[2]); \
|
||||
c.val[3] = vmulq_f64(a.val[3], b.val[3]);
|
||||
|
||||
#define vfmulx4_i(c, a, b) \
|
||||
c.val[0] = vmulq_f64(a.val[0], b); \
|
||||
c.val[1] = vmulq_f64(a.val[1], b); \
|
||||
c.val[2] = vmulq_f64(a.val[2], b); \
|
||||
c.val[3] = vmulq_f64(a.val[3], b);
|
||||
|
||||
#define vfinvx4(c, a) \
|
||||
c.val[0] = vdivq_f64(vdupq_n_f64(1.0), a.val[0]); \
|
||||
c.val[1] = vdivq_f64(vdupq_n_f64(1.0), a.val[1]); \
|
||||
c.val[2] = vdivq_f64(vdupq_n_f64(1.0), a.val[2]); \
|
||||
c.val[3] = vdivq_f64(vdupq_n_f64(1.0), a.val[3]);
|
||||
|
||||
#define vfcvtx4(c, a) \
|
||||
c.val[0] = vcvtq_f64_s64(a.val[0]); \
|
||||
c.val[1] = vcvtq_f64_s64(a.val[1]); \
|
||||
c.val[2] = vcvtq_f64_s64(a.val[2]); \
|
||||
c.val[3] = vcvtq_f64_s64(a.val[3]);
|
||||
|
||||
#define vfmlax4(d, c, a, b) \
|
||||
vfmla(d.val[0], c.val[0], a.val[0], b.val[0]); \
|
||||
vfmla(d.val[1], c.val[1], a.val[1], b.val[1]); \
|
||||
vfmla(d.val[2], c.val[2], a.val[2], b.val[2]); \
|
||||
vfmla(d.val[3], c.val[3], a.val[3], b.val[3]);
|
||||
|
||||
#define vfmlsx4(d, c, a, b) \
|
||||
vfmls(d.val[0], c.val[0], a.val[0], b.val[0]); \
|
||||
vfmls(d.val[1], c.val[1], a.val[1], b.val[1]); \
|
||||
vfmls(d.val[2], c.val[2], a.val[2], b.val[2]); \
|
||||
vfmls(d.val[3], c.val[3], a.val[3], b.val[3]);
|
||||
|
||||
#define vfrintx4(c, a) \
|
||||
c.val[0] = vcvtnq_s64_f64(a.val[0]); \
|
||||
c.val[1] = vcvtnq_s64_f64(a.val[1]); \
|
||||
c.val[2] = vcvtnq_s64_f64(a.val[2]); \
|
||||
c.val[3] = vcvtnq_s64_f64(a.val[3]);
|
||||
|
||||
/*
|
||||
* Wrapper for FFT, split/merge and poly_float.c
|
||||
*/
|
||||
|
||||
#define FPC_MUL(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re, a_re, b_re); \
|
||||
vfmls(d_re, d_re, a_im, b_im); \
|
||||
vfmul(d_im, a_re, b_im); \
|
||||
vfmla(d_im, d_im, a_im, b_re);
|
||||
|
||||
#define FPC_MULx2(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||
vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||
vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||
vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]);
|
||||
|
||||
#define FPC_MULx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||
vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||
vfmul(d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||
vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \
|
||||
vfmul(d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||
vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \
|
||||
vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||
vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \
|
||||
vfmul(d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||
vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \
|
||||
vfmul(d_im.val[3], a_re.val[3], b_im.val[3]); \
|
||||
vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]);
|
||||
|
||||
#define FPC_MLA(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmla(d_re, d_re, a_re, b_re); \
|
||||
vfmls(d_re, d_re, a_im, b_im); \
|
||||
vfmla(d_im, d_im, a_re, b_im); \
|
||||
vfmla(d_im, d_im, a_im, b_re);
|
||||
|
||||
#define FPC_MLAx2(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]);
|
||||
|
||||
#define FPC_MLAx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \
|
||||
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \
|
||||
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||
vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \
|
||||
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||
vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \
|
||||
vfmla(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||
vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \
|
||||
vfmla(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); \
|
||||
vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]);
|
||||
|
||||
#define FPC_MUL_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re.val[0], b_im.val[0], a_im.val[0]); \
|
||||
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmul(d_re.val[1], b_im.val[1], a_im.val[1]); \
|
||||
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmul(d_re.val[2], b_im.val[2], a_im.val[2]); \
|
||||
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||
vfmul(d_re.val[3], b_im.val[3], a_im.val[3]); \
|
||||
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||
vfmul(d_im.val[0], b_re.val[0], a_im.val[0]); \
|
||||
vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmul(d_im.val[1], b_re.val[1], a_im.val[1]); \
|
||||
vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmul(d_im.val[2], b_re.val[2], a_im.val[2]); \
|
||||
vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||
vfmul(d_im.val[3], b_re.val[3], a_im.val[3]); \
|
||||
vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]);
|
||||
|
||||
#define FPC_MLA_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmla(d_re.val[0], d_re.val[0], b_im.val[0], a_im.val[0]); \
|
||||
vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \
|
||||
vfmla(d_re.val[1], d_re.val[1], b_im.val[1], a_im.val[1]); \
|
||||
vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \
|
||||
vfmla(d_re.val[2], d_re.val[2], b_im.val[2], a_im.val[2]); \
|
||||
vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \
|
||||
vfmla(d_re.val[3], d_re.val[3], b_im.val[3], a_im.val[3]); \
|
||||
vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \
|
||||
vfmla(d_im.val[0], d_im.val[0], b_re.val[0], a_im.val[0]); \
|
||||
vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \
|
||||
vfmla(d_im.val[1], d_im.val[1], b_re.val[1], a_im.val[1]); \
|
||||
vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \
|
||||
vfmla(d_im.val[2], d_im.val[2], b_re.val[2], a_im.val[2]); \
|
||||
vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \
|
||||
vfmla(d_im.val[3], d_im.val[3], b_re.val[3], a_im.val[3]); \
|
||||
vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]);
|
||||
|
||||
#define FPC_MUL_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re, a_re, b_re_im, 0); \
|
||||
vfmls_lane(d_re, d_re, a_im, b_re_im, 1); \
|
||||
vfmul_lane(d_im, a_re, b_re_im, 1); \
|
||||
vfmla_lane(d_im, d_im, a_im, b_re_im, 0);
|
||||
|
||||
#define FPC_MUL_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \
|
||||
vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \
|
||||
vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \
|
||||
vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \
|
||||
vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \
|
||||
vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 1); \
|
||||
vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 1); \
|
||||
vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 1); \
|
||||
vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 1); \
|
||||
vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 0);
|
||||
|
||||
#define FWD_TOP(t_re, t_im, b_re, b_im, zeta_re, zeta_im) FPC_MUL(t_re, t_im, b_re, b_im, zeta_re, zeta_im);
|
||||
|
||||
#define FWD_TOP_LANE(t_re, t_im, b_re, b_im, zeta) FPC_MUL_LANE(t_re, t_im, b_re, b_im, zeta);
|
||||
|
||||
#define FWD_TOP_LANEx4(t_re, t_im, b_re, b_im, zeta) FPC_MUL_LANEx4(t_re, t_im, b_re, b_im, zeta);
|
||||
|
||||
/*
|
||||
* FPC
|
||||
*/
|
||||
|
||||
#define FPC_SUB(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re = vsubq_f64(a_re, b_re); \
|
||||
d_im = vsubq_f64(a_im, b_im);
|
||||
|
||||
#define FPC_SUBx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re.val[0] = vsubq_f64(a_re.val[0], b_re.val[0]); \
|
||||
d_im.val[0] = vsubq_f64(a_im.val[0], b_im.val[0]); \
|
||||
d_re.val[1] = vsubq_f64(a_re.val[1], b_re.val[1]); \
|
||||
d_im.val[1] = vsubq_f64(a_im.val[1], b_im.val[1]); \
|
||||
d_re.val[2] = vsubq_f64(a_re.val[2], b_re.val[2]); \
|
||||
d_im.val[2] = vsubq_f64(a_im.val[2], b_im.val[2]); \
|
||||
d_re.val[3] = vsubq_f64(a_re.val[3], b_re.val[3]); \
|
||||
d_im.val[3] = vsubq_f64(a_im.val[3], b_im.val[3]);
|
||||
|
||||
#define FPC_ADD(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re = vaddq_f64(a_re, b_re); \
|
||||
d_im = vaddq_f64(a_im, b_im);
|
||||
|
||||
#define FPC_ADDx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re.val[0] = vaddq_f64(a_re.val[0], b_re.val[0]); \
|
||||
d_im.val[0] = vaddq_f64(a_im.val[0], b_im.val[0]); \
|
||||
d_re.val[1] = vaddq_f64(a_re.val[1], b_re.val[1]); \
|
||||
d_im.val[1] = vaddq_f64(a_im.val[1], b_im.val[1]); \
|
||||
d_re.val[2] = vaddq_f64(a_re.val[2], b_re.val[2]); \
|
||||
d_im.val[2] = vaddq_f64(a_im.val[2], b_im.val[2]); \
|
||||
d_re.val[3] = vaddq_f64(a_re.val[3], b_re.val[3]); \
|
||||
d_im.val[3] = vaddq_f64(a_im.val[3], b_im.val[3]);
|
||||
|
||||
#define FWD_BOT(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||
FPC_SUB(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||
FPC_ADD(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||
|
||||
#define FWD_BOTx4(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||
FPC_SUBx4(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||
FPC_ADDx4(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||
|
||||
/*
|
||||
* FPC_J
|
||||
*/
|
||||
|
||||
#define FPC_ADDJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re = vsubq_f64(a_re, b_im); \
|
||||
d_im = vaddq_f64(a_im, b_re);
|
||||
|
||||
#define FPC_ADDJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re.val[0] = vsubq_f64(a_re.val[0], b_im.val[0]); \
|
||||
d_im.val[0] = vaddq_f64(a_im.val[0], b_re.val[0]); \
|
||||
d_re.val[1] = vsubq_f64(a_re.val[1], b_im.val[1]); \
|
||||
d_im.val[1] = vaddq_f64(a_im.val[1], b_re.val[1]); \
|
||||
d_re.val[2] = vsubq_f64(a_re.val[2], b_im.val[2]); \
|
||||
d_im.val[2] = vaddq_f64(a_im.val[2], b_re.val[2]); \
|
||||
d_re.val[3] = vsubq_f64(a_re.val[3], b_im.val[3]); \
|
||||
d_im.val[3] = vaddq_f64(a_im.val[3], b_re.val[3]);
|
||||
|
||||
#define FPC_SUBJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re = vaddq_f64(a_re, b_im); \
|
||||
d_im = vsubq_f64(a_im, b_re);
|
||||
|
||||
#define FPC_SUBJx4(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
d_re.val[0] = vaddq_f64(a_re.val[0], b_im.val[0]); \
|
||||
d_im.val[0] = vsubq_f64(a_im.val[0], b_re.val[0]); \
|
||||
d_re.val[1] = vaddq_f64(a_re.val[1], b_im.val[1]); \
|
||||
d_im.val[1] = vsubq_f64(a_im.val[1], b_re.val[1]); \
|
||||
d_re.val[2] = vaddq_f64(a_re.val[2], b_im.val[2]); \
|
||||
d_im.val[2] = vsubq_f64(a_im.val[2], b_re.val[2]); \
|
||||
d_re.val[3] = vaddq_f64(a_re.val[3], b_im.val[3]); \
|
||||
d_im.val[3] = vsubq_f64(a_im.val[3], b_re.val[3]);
|
||||
|
||||
#define FWD_BOTJ(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||
FPC_SUBJ(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||
FPC_ADDJ(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||
|
||||
#define FWD_BOTJx4(a_re, a_im, b_re, b_im, t_re, t_im) \
|
||||
FPC_SUBJx4(b_re, b_im, a_re, a_im, t_re, t_im); \
|
||||
FPC_ADDJx4(a_re, a_im, a_re, a_im, t_re, t_im);
|
||||
|
||||
//============== Inverse FFT
|
||||
/*
|
||||
* FPC_J
|
||||
* a * conj(b)
|
||||
* Original (without swap):
|
||||
* d_re = b_im * a_im + a_re * b_re;
|
||||
* d_im = b_re * a_im - a_re * b_im;
|
||||
*/
|
||||
#define FPC_MUL_BOTJ_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re, a_re, b_re_im, 0); \
|
||||
vfmla_lane(d_re, d_re, a_im, b_re_im, 1); \
|
||||
vfmul_lane(d_im, a_im, b_re_im, 0); \
|
||||
vfmls_lane(d_im, d_im, a_re, b_re_im, 1);
|
||||
|
||||
#define FPC_MUL_BOTJ_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \
|
||||
vfmla_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \
|
||||
vfmul_lane(d_im.val[0], a_im.val[0], b_re_im, 0); \
|
||||
vfmls_lane(d_im.val[0], d_im.val[0], a_re.val[0], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \
|
||||
vfmla_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \
|
||||
vfmul_lane(d_im.val[1], a_im.val[1], b_re_im, 0); \
|
||||
vfmls_lane(d_im.val[1], d_im.val[1], a_re.val[1], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \
|
||||
vfmla_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \
|
||||
vfmul_lane(d_im.val[2], a_im.val[2], b_re_im, 0); \
|
||||
vfmls_lane(d_im.val[2], d_im.val[2], a_re.val[2], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \
|
||||
vfmla_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \
|
||||
vfmul_lane(d_im.val[3], a_im.val[3], b_re_im, 0); \
|
||||
vfmls_lane(d_im.val[3], d_im.val[3], a_re.val[3], b_re_im, 1);
|
||||
|
||||
#define FPC_MUL_BOTJ(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re, b_im, a_im); \
|
||||
vfmla(d_re, d_re, a_re, b_re); \
|
||||
vfmul(d_im, b_re, a_im); \
|
||||
vfmls(d_im, d_im, a_re, b_im);
|
||||
|
||||
#define INV_TOPJ(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||
FPC_SUB(t_re, t_im, a_re, a_im, b_re, b_im); \
|
||||
FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||
|
||||
#define INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||
FPC_SUBx4(t_re, t_im, a_re, a_im, b_re, b_im); \
|
||||
FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||
|
||||
#define INV_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im) FPC_MUL_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im);
|
||||
|
||||
#define INV_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta);
|
||||
|
||||
#define INV_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta);
|
||||
|
||||
/*
|
||||
* FPC_Jm
|
||||
* a * -conj(b)
|
||||
* d_re = a_re * b_im - a_im * b_re;
|
||||
* d_im = a_im * b_im + a_re * b_re;
|
||||
*/
|
||||
#define FPC_MUL_BOTJm_LANE(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re, a_re, b_re_im, 1); \
|
||||
vfmls_lane(d_re, d_re, a_im, b_re_im, 0); \
|
||||
vfmul_lane(d_im, a_re, b_re_im, 0); \
|
||||
vfmla_lane(d_im, d_im, a_im, b_re_im, 1);
|
||||
|
||||
#define FPC_MUL_BOTJm_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \
|
||||
vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 1); \
|
||||
vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 0); \
|
||||
vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 1); \
|
||||
vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 0); \
|
||||
vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 1); \
|
||||
vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 0); \
|
||||
vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 1); \
|
||||
vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 1); \
|
||||
vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 0); \
|
||||
vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 0); \
|
||||
vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 1);
|
||||
|
||||
#define FPC_MUL_BOTJm(d_re, d_im, a_re, a_im, b_re, b_im) \
|
||||
vfmul(d_re, a_re, b_im); \
|
||||
vfmls(d_re, d_re, a_im, b_re); \
|
||||
vfmul(d_im, a_im, b_im); \
|
||||
vfmla(d_im, d_im, a_re, b_re);
|
||||
|
||||
#define INV_TOPJm(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||
FPC_SUB(t_re, t_im, b_re, b_im, a_re, a_im); \
|
||||
FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||
|
||||
#define INV_TOPJmx4(t_re, t_im, a_re, a_im, b_re, b_im) \
|
||||
FPC_SUBx4(t_re, t_im, b_re, b_im, a_re, a_im); \
|
||||
FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im);
|
||||
|
||||
#define INV_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im) FPC_MUL_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im);
|
||||
|
||||
#define INV_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta);
|
||||
|
||||
#define INV_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta);
|
||||
|
||||
#endif
|
||||
@@ -1,115 +0,0 @@
|
||||
#ifndef SPQLIOS_Q120_ARITHMETIC_H
|
||||
#define SPQLIOS_Q120_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
#include "q120_common.h"
|
||||
|
||||
typedef struct _q120_mat1col_product_baa_precomp q120_mat1col_product_baa_precomp;
|
||||
typedef struct _q120_mat1col_product_bbb_precomp q120_mat1col_product_bbb_precomp;
|
||||
typedef struct _q120_mat1col_product_bbc_precomp q120_mat1col_product_bbc_precomp;
|
||||
|
||||
EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp();
|
||||
EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp*);
|
||||
EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp();
|
||||
EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp*);
|
||||
EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp();
|
||||
EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp*);
|
||||
|
||||
// ell < 10000
|
||||
EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120a* const x, const q120a* const y);
|
||||
EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120b* const x, const q120b* const y);
|
||||
EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120b* const x, const q120c* const y);
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120a* const x, const q120a* const y);
|
||||
EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120b* const x, const q120b* const y);
|
||||
EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res,
|
||||
const q120b* const x, const q120c* const y);
|
||||
|
||||
EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y);
|
||||
EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y);
|
||||
EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y);
|
||||
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y);
|
||||
|
||||
/**
|
||||
* @brief extract 1 q120x2 block from one q120 ntt vectors
|
||||
* @param nn the size of each vector
|
||||
* @param blk the block id to extract (<nn/2)
|
||||
* @param dst the output: nrows q120x2's dst[i] = src[i](blk)
|
||||
* @param src the input: nrows q120 ntt vecs's
|
||||
*/
|
||||
EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk,
|
||||
q120x2b* const dst, // 8 doubles
|
||||
const q120b* const src // a reim vector
|
||||
);
|
||||
EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk,
|
||||
q120x2c* const dst, // 8 doubles
|
||||
const q120c* const src // a reim vector
|
||||
);
|
||||
EXPORT void q120x2_extract_1blk_from_q120b_avx(uint64_t nn, uint64_t blk,
|
||||
q120x2b* const dst, // 8 doubles
|
||||
const q120b* const src // a reim vector
|
||||
);
|
||||
EXPORT void q120x2_extract_1blk_from_q120c_avx(uint64_t nn, uint64_t blk,
|
||||
q120x2c* const dst, // 8 doubles
|
||||
const q120c* const src // a reim vector
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief extract 1 reim4 block from nrows reim vectors of m complexes
|
||||
* @param nn the size of each q120
|
||||
* @param nrows the number of q120 (ntt) vectors
|
||||
* @param blk the block id to extract (<m/4)
|
||||
* @param dst the output: nrows q120x2's dst[i] = src[i](blk)
|
||||
* @param src the input: nrows q120 ntt vectors
|
||||
*/
|
||||
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref(
|
||||
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||
q120x2b* const dst, // nrows * 2 q120
|
||||
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||
);
|
||||
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_avx(
|
||||
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||
q120x2b* const dst, // nrows * 2 q120
|
||||
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief saves 1 single q120x2 block in a q120 vectors of size nn
|
||||
* @param nn the size of the output q120
|
||||
* @param blk the block id to save (<nn/2)
|
||||
* @param dest the output q120b vector: dst(blk) = src
|
||||
* @param src the input q120x2b
|
||||
*/
|
||||
EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk,
|
||||
q120b* dest, // 1 reim vector of length m
|
||||
const q120x2b* src // 8 doubles
|
||||
);
|
||||
EXPORT void q120x2b_save_1blk_to_q120b_avx(uint64_t nn, uint64_t blk,
|
||||
q120b* dest, // 1 reim vector of length m
|
||||
const q120x2b* src // 8 doubles
|
||||
);
|
||||
|
||||
EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y);
|
||||
|
||||
EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y);
|
||||
|
||||
EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x);
|
||||
|
||||
EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x);
|
||||
|
||||
EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x);
|
||||
|
||||
EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x);
|
||||
|
||||
#endif // SPQLIOS_Q120_ARITHMETIC_H
|
||||
@@ -1,567 +0,0 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "q120_arithmetic.h"
|
||||
#include "q120_arithmetic_private.h"
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120a* const x, const q120a* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* - res = acc1 + acc2 . ((2^H) % Q)
|
||||
* - acc1 is the sum of H LSB of products x[i].y[i]
|
||||
* - acc2 is the sum of 64-H MSB of products x[i]].y[i]
|
||||
* - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits
|
||||
* - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits
|
||||
*/
|
||||
|
||||
const uint64_t H = precomp->h;
|
||||
const __m256i MASK = _mm256_set1_epi64x((UINT64_C(1) << H) - 1);
|
||||
|
||||
__m256i acc1 = _mm256_setzero_si256();
|
||||
__m256i acc2 = _mm256_setzero_si256();
|
||||
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
__m256i a = _mm256_loadu_si256(x_ptr);
|
||||
__m256i b = _mm256_loadu_si256(y_ptr);
|
||||
__m256i t = _mm256_mul_epu32(a, b);
|
||||
|
||||
acc1 = _mm256_add_epi64(acc1, _mm256_and_si256(t, MASK));
|
||||
acc2 = _mm256_add_epi64(acc2, _mm256_srli_epi64(t, H));
|
||||
|
||||
x_ptr++;
|
||||
y_ptr++;
|
||||
}
|
||||
|
||||
const __m256i H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->h_pow_red);
|
||||
|
||||
__m256i t = _mm256_add_epi64(acc1, _mm256_mul_epu32(acc2, H_POW_RED));
|
||||
_mm256_storeu_si256((__m256i*)res, t);
|
||||
}
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120b* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products:
|
||||
* - x_i = xl_i + xh_i . 2^32
|
||||
* - y_i = yl_i + yh_i . 2^32
|
||||
* - A_i = xl_i . yl_i
|
||||
* - B_i = xl_i . yh_i
|
||||
* - C_i = xh_i . yl_i
|
||||
* - D_i = xh_i . yh_i
|
||||
* - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64
|
||||
* 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts
|
||||
* - A_i = Al_i + Ah_i . 2^32
|
||||
* - B_i = Bl_i + Bh_i . 2^32
|
||||
* - C_i = Cl_i + Ch_i . 2^32
|
||||
* - D_i = Dl_i + Dh_i . 2^32
|
||||
* 3. Compute the sums:
|
||||
* - S1 = \sum Al_i
|
||||
* - S2 = \sum (Ah_i + Bl_i + Cl_i)
|
||||
* - S3 = \sum (Bh_i + Ch_i + Dl_i)
|
||||
* - S4 = \sum Dh_i
|
||||
* - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) +
|
||||
* log2(3) bits
|
||||
* - for ell == 10000 S2, S3 have < 47 bits
|
||||
* 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2))
|
||||
* - S1 = S1l + S1h . 2^24
|
||||
* - S2 = S2l + S2h . 2^24
|
||||
* - S3 = S3l + S3h . 2^24
|
||||
* - S4 = S4l + S4h . 2^24
|
||||
* 5. Compute final result as:
|
||||
* - \sum x_i . y_i = S1l + S1h . 2^24
|
||||
* + S2l . 2^32 + S2h . 2^(32+24)
|
||||
* + S3l . 2^64 + S3h . 2^(64 + 24)
|
||||
* + S4l . 2^96 + S4l . 2^(96+24)
|
||||
* - here the powers of 2 are reduced modulo the primes Q before
|
||||
* multiplications
|
||||
* - the result will be on 24 + 3 + bit size of primes Q
|
||||
*/
|
||||
const uint64_t H1 = 32;
|
||||
const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1);
|
||||
|
||||
__m256i s1 = _mm256_setzero_si256();
|
||||
__m256i s2 = _mm256_setzero_si256();
|
||||
__m256i s3 = _mm256_setzero_si256();
|
||||
__m256i s4 = _mm256_setzero_si256();
|
||||
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
__m256i x = _mm256_loadu_si256(x_ptr);
|
||||
__m256i xl = _mm256_and_si256(x, MASK1);
|
||||
__m256i xh = _mm256_srli_epi64(x, H1);
|
||||
|
||||
__m256i y = _mm256_loadu_si256(y_ptr);
|
||||
__m256i yl = _mm256_and_si256(y, MASK1);
|
||||
__m256i yh = _mm256_srli_epi64(y, H1);
|
||||
|
||||
__m256i a = _mm256_mul_epu32(xl, yl);
|
||||
__m256i b = _mm256_mul_epu32(xl, yh);
|
||||
__m256i c = _mm256_mul_epu32(xh, yl);
|
||||
__m256i d = _mm256_mul_epu32(xh, yh);
|
||||
|
||||
s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1));
|
||||
|
||||
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1));
|
||||
s2 = _mm256_add_epi64(s2, _mm256_and_si256(b, MASK1));
|
||||
s2 = _mm256_add_epi64(s2, _mm256_and_si256(c, MASK1));
|
||||
|
||||
s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(b, H1));
|
||||
s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(c, H1));
|
||||
s3 = _mm256_add_epi64(s3, _mm256_and_si256(d, MASK1));
|
||||
|
||||
s4 = _mm256_add_epi64(s4, _mm256_srli_epi64(d, H1));
|
||||
|
||||
x_ptr++;
|
||||
y_ptr++;
|
||||
}
|
||||
|
||||
const uint64_t H2 = precomp->h;
|
||||
const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1);
|
||||
|
||||
const __m256i S1H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s1h_pow_red);
|
||||
__m256i s1l = _mm256_and_si256(s1, MASK2);
|
||||
__m256i s1h = _mm256_srli_epi64(s1, H2);
|
||||
__m256i t = _mm256_add_epi64(s1l, _mm256_mul_epu32(s1h, S1H_POW_RED));
|
||||
|
||||
const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red);
|
||||
const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red);
|
||||
__m256i s2l = _mm256_and_si256(s2, MASK2);
|
||||
__m256i s2h = _mm256_srli_epi64(s2, H2);
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED));
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED));
|
||||
|
||||
const __m256i S3L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3l_pow_red);
|
||||
const __m256i S3H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3h_pow_red);
|
||||
__m256i s3l = _mm256_and_si256(s3, MASK2);
|
||||
__m256i s3h = _mm256_srli_epi64(s3, H2);
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s3l, S3L_POW_RED));
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s3h, S3H_POW_RED));
|
||||
|
||||
const __m256i S4L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4l_pow_red);
|
||||
const __m256i S4H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4h_pow_red);
|
||||
__m256i s4l = _mm256_and_si256(s4, MASK2);
|
||||
__m256i s4h = _mm256_srli_epi64(s4, H2);
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s4l, S4L_POW_RED));
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s4h, S4H_POW_RED));
|
||||
|
||||
_mm256_storeu_si256((__m256i*)res, t);
|
||||
}
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* 0. We have
|
||||
* - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q
|
||||
* 1. Split x_i in 2 32-bit parts and compute the cross-products:
|
||||
* - x_i = xl_i + xh_i . 2^32
|
||||
* - A_i = xl_i . y1_i
|
||||
* - B_i = xh_i . y2_i
|
||||
* - we have x_i . y_i == A_i + B_i
|
||||
* 2. Split A_i and B_i into 2 32-bit parts
|
||||
* - A_i = Al_i + Ah_i . 2^32
|
||||
* - B_i = Bl_i + Bh_i . 2^32
|
||||
* 3. Compute the sums:
|
||||
* - S1 = \sum Al_i + Bl_i
|
||||
* - S2 = \sum Ah_i + Bh_i
|
||||
* - here S1 and S2 have 32 + log2(ell) bits
|
||||
* - for ell == 10000 S1, S2 have < 46 bits
|
||||
* 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46)
|
||||
* - S2 = S2l + S2h . 2^27
|
||||
* 5. Compute final result as:
|
||||
* - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27)
|
||||
* - here the powers of 2 are reduced modulo the primes Q before
|
||||
* multiplications
|
||||
* - the result will be on < 52 bits
|
||||
*/
|
||||
|
||||
const uint64_t H1 = 32;
|
||||
const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1);
|
||||
|
||||
__m256i s1 = _mm256_setzero_si256();
|
||||
__m256i s2 = _mm256_setzero_si256();
|
||||
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
__m256i x = _mm256_loadu_si256(x_ptr);
|
||||
__m256i xl = _mm256_and_si256(x, MASK1);
|
||||
__m256i xh = _mm256_srli_epi64(x, H1);
|
||||
|
||||
__m256i y = _mm256_loadu_si256(y_ptr);
|
||||
__m256i y0 = _mm256_and_si256(y, MASK1);
|
||||
__m256i y1 = _mm256_srli_epi64(y, H1);
|
||||
|
||||
__m256i a = _mm256_mul_epu32(xl, y0);
|
||||
__m256i b = _mm256_mul_epu32(xh, y1);
|
||||
|
||||
s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1));
|
||||
s1 = _mm256_add_epi64(s1, _mm256_and_si256(b, MASK1));
|
||||
|
||||
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1));
|
||||
s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(b, H1));
|
||||
|
||||
x_ptr++;
|
||||
y_ptr++;
|
||||
}
|
||||
|
||||
const uint64_t H2 = precomp->h;
|
||||
const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1);
|
||||
|
||||
__m256i t = s1;
|
||||
|
||||
const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red);
|
||||
const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red);
|
||||
__m256i s2l = _mm256_and_si256(s2, MASK2);
|
||||
__m256i s2h = _mm256_srli_epi64(s2, H2);
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED));
|
||||
t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED));
|
||||
|
||||
_mm256_storeu_si256((__m256i*)res, t);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated keeping this one for history only.
|
||||
* There is a slight register starvation condition on the q120x2_vec_mat2cols
|
||||
* strategy below sounds better.
|
||||
*/
|
||||
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||
__m256i s1 = _mm256_setzero_si256();
|
||||
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||
__m256i s3 = _mm256_setzero_si256();
|
||||
__m256i s4 = _mm256_setzero_si256(); // col 2a
|
||||
__m256i s5 = _mm256_setzero_si256();
|
||||
__m256i s6 = _mm256_setzero_si256(); // col 2b
|
||||
__m256i s7 = _mm256_setzero_si256();
|
||||
__m256i s8, s9, s10, s11;
|
||||
__m256i s12, s13, s14, s15;
|
||||
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
__m256i* res_ptr = (__m256i*)res;
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
s8 = _mm256_loadu_si256(x_ptr);
|
||||
s9 = _mm256_loadu_si256(x_ptr + 1);
|
||||
s10 = _mm256_srli_epi64(s8, 32);
|
||||
s11 = _mm256_srli_epi64(s9, 32);
|
||||
|
||||
s12 = _mm256_loadu_si256(y_ptr);
|
||||
s13 = _mm256_loadu_si256(y_ptr + 1);
|
||||
s14 = _mm256_srli_epi64(s12, 32);
|
||||
s15 = _mm256_srli_epi64(s13, 32);
|
||||
|
||||
s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1
|
||||
s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3
|
||||
s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1
|
||||
s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3
|
||||
|
||||
s10 = _mm256_slli_epi64(s12, 32); // -> s0
|
||||
s11 = _mm256_slli_epi64(s13, 32); // -> s2
|
||||
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||
s13 = _mm256_srli_epi64(s13, 32); // -> s3
|
||||
s10 = _mm256_srli_epi64(s10, 32); // -> s0
|
||||
s11 = _mm256_srli_epi64(s11, 32); // -> s2
|
||||
|
||||
s0 = _mm256_add_epi64(s0, s10);
|
||||
s1 = _mm256_add_epi64(s1, s12);
|
||||
s2 = _mm256_add_epi64(s2, s11);
|
||||
s3 = _mm256_add_epi64(s3, s13);
|
||||
|
||||
s10 = _mm256_slli_epi64(s14, 32); // -> s0
|
||||
s11 = _mm256_slli_epi64(s15, 32); // -> s2
|
||||
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||
s15 = _mm256_srli_epi64(s15, 32); // -> s3
|
||||
s10 = _mm256_srli_epi64(s10, 32); // -> s0
|
||||
s11 = _mm256_srli_epi64(s11, 32); // -> s2
|
||||
|
||||
s0 = _mm256_add_epi64(s0, s10);
|
||||
s1 = _mm256_add_epi64(s1, s14);
|
||||
s2 = _mm256_add_epi64(s2, s11);
|
||||
s3 = _mm256_add_epi64(s3, s15);
|
||||
|
||||
// deal with the second column
|
||||
// s8,s9 are still in place!
|
||||
s10 = _mm256_srli_epi64(s8, 32);
|
||||
s11 = _mm256_srli_epi64(s9, 32);
|
||||
|
||||
s12 = _mm256_loadu_si256(y_ptr + 2);
|
||||
s13 = _mm256_loadu_si256(y_ptr + 3);
|
||||
s14 = _mm256_srli_epi64(s12, 32);
|
||||
s15 = _mm256_srli_epi64(s13, 32);
|
||||
|
||||
s12 = _mm256_mul_epu32(s8, s12); // -> s4,s5
|
||||
s13 = _mm256_mul_epu32(s9, s13); // -> s6,s7
|
||||
s14 = _mm256_mul_epu32(s10, s14); // -> s4,s5
|
||||
s15 = _mm256_mul_epu32(s11, s15); // -> s6,s7
|
||||
|
||||
s10 = _mm256_slli_epi64(s12, 32); // -> s4
|
||||
s11 = _mm256_slli_epi64(s13, 32); // -> s6
|
||||
s12 = _mm256_srli_epi64(s12, 32); // -> s5
|
||||
s13 = _mm256_srli_epi64(s13, 32); // -> s7
|
||||
s10 = _mm256_srli_epi64(s10, 32); // -> s4
|
||||
s11 = _mm256_srli_epi64(s11, 32); // -> s6
|
||||
|
||||
s4 = _mm256_add_epi64(s4, s10);
|
||||
s5 = _mm256_add_epi64(s5, s12);
|
||||
s6 = _mm256_add_epi64(s6, s11);
|
||||
s7 = _mm256_add_epi64(s7, s13);
|
||||
|
||||
s10 = _mm256_slli_epi64(s14, 32); // -> s4
|
||||
s11 = _mm256_slli_epi64(s15, 32); // -> s6
|
||||
s14 = _mm256_srli_epi64(s14, 32); // -> s5
|
||||
s15 = _mm256_srli_epi64(s15, 32); // -> s7
|
||||
s10 = _mm256_srli_epi64(s10, 32); // -> s4
|
||||
s11 = _mm256_srli_epi64(s11, 32); // -> s6
|
||||
|
||||
s4 = _mm256_add_epi64(s4, s10);
|
||||
s5 = _mm256_add_epi64(s5, s14);
|
||||
s6 = _mm256_add_epi64(s6, s11);
|
||||
s7 = _mm256_add_epi64(s7, s15);
|
||||
|
||||
x_ptr += 2;
|
||||
y_ptr += 4;
|
||||
}
|
||||
// final reduction
|
||||
const uint64_t H2 = precomp->h;
|
||||
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||
//--- s0,s1
|
||||
s11 = _mm256_and_si256(s1, s8);
|
||||
s12 = _mm256_srli_epi64(s1, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s0 = _mm256_add_epi64(s0, s13);
|
||||
s0 = _mm256_add_epi64(s0, s14);
|
||||
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||
//--- s2,s3
|
||||
s11 = _mm256_and_si256(s3, s8);
|
||||
s12 = _mm256_srli_epi64(s3, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s2 = _mm256_add_epi64(s2, s13);
|
||||
s2 = _mm256_add_epi64(s2, s14);
|
||||
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||
//--- s4,s5
|
||||
s11 = _mm256_and_si256(s5, s8);
|
||||
s12 = _mm256_srli_epi64(s5, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s4 = _mm256_add_epi64(s4, s13);
|
||||
s4 = _mm256_add_epi64(s4, s14);
|
||||
_mm256_storeu_si256(res_ptr + 2, s4);
|
||||
//--- s6,s7
|
||||
s11 = _mm256_and_si256(s7, s8);
|
||||
s12 = _mm256_srli_epi64(s7, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s6 = _mm256_add_epi64(s6, s13);
|
||||
s6 = _mm256_add_epi64(s6, s14);
|
||||
_mm256_storeu_si256(res_ptr + 3, s6);
|
||||
}
|
||||
|
||||
EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||
__m256i s1 = _mm256_setzero_si256();
|
||||
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||
__m256i s3 = _mm256_setzero_si256();
|
||||
__m256i s4 = _mm256_setzero_si256(); // col 2a
|
||||
__m256i s5 = _mm256_setzero_si256();
|
||||
__m256i s6 = _mm256_setzero_si256(); // col 2b
|
||||
__m256i s7 = _mm256_setzero_si256();
|
||||
__m256i s8, s9, s10, s11;
|
||||
__m256i s12, s13, s14, s15;
|
||||
|
||||
s11 = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
__m256i* res_ptr = (__m256i*)res;
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
// treat item a
|
||||
s8 = _mm256_loadu_si256(x_ptr);
|
||||
s9 = _mm256_srli_epi64(s8, 32);
|
||||
|
||||
s12 = _mm256_loadu_si256(y_ptr);
|
||||
s13 = _mm256_loadu_si256(y_ptr + 2);
|
||||
s14 = _mm256_srli_epi64(s12, 32);
|
||||
s15 = _mm256_srli_epi64(s13, 32);
|
||||
|
||||
s12 = _mm256_mul_epu32(s8, s12); // c1a -> s0,s1
|
||||
s13 = _mm256_mul_epu32(s8, s13); // c2a -> s4,s5
|
||||
s14 = _mm256_mul_epu32(s9, s14); // c1a -> s0,s1
|
||||
s15 = _mm256_mul_epu32(s9, s15); // c2a -> s4,s5
|
||||
|
||||
s8 = _mm256_and_si256(s12, s11); // -> s0
|
||||
s9 = _mm256_and_si256(s13, s11); // -> s4
|
||||
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||
s13 = _mm256_srli_epi64(s13, 32); // -> s5
|
||||
s0 = _mm256_add_epi64(s0, s8);
|
||||
s1 = _mm256_add_epi64(s1, s12);
|
||||
s4 = _mm256_add_epi64(s4, s9);
|
||||
s5 = _mm256_add_epi64(s5, s13);
|
||||
|
||||
s8 = _mm256_and_si256(s14, s11); // -> s0
|
||||
s9 = _mm256_and_si256(s15, s11); // -> s4
|
||||
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||
s15 = _mm256_srli_epi64(s15, 32); // -> s5
|
||||
s0 = _mm256_add_epi64(s0, s8);
|
||||
s1 = _mm256_add_epi64(s1, s14);
|
||||
s4 = _mm256_add_epi64(s4, s9);
|
||||
s5 = _mm256_add_epi64(s5, s15);
|
||||
|
||||
// treat item b
|
||||
s8 = _mm256_loadu_si256(x_ptr + 1);
|
||||
s9 = _mm256_srli_epi64(s8, 32);
|
||||
|
||||
s12 = _mm256_loadu_si256(y_ptr + 1);
|
||||
s13 = _mm256_loadu_si256(y_ptr + 3);
|
||||
s14 = _mm256_srli_epi64(s12, 32);
|
||||
s15 = _mm256_srli_epi64(s13, 32);
|
||||
|
||||
s12 = _mm256_mul_epu32(s8, s12); // c1b -> s2,s3
|
||||
s13 = _mm256_mul_epu32(s8, s13); // c2b -> s6,s7
|
||||
s14 = _mm256_mul_epu32(s9, s14); // c1b -> s2,s3
|
||||
s15 = _mm256_mul_epu32(s9, s15); // c2b -> s6,s7
|
||||
|
||||
s8 = _mm256_and_si256(s12, s11); // -> s2
|
||||
s9 = _mm256_and_si256(s13, s11); // -> s6
|
||||
s12 = _mm256_srli_epi64(s12, 32); // -> s3
|
||||
s13 = _mm256_srli_epi64(s13, 32); // -> s7
|
||||
s2 = _mm256_add_epi64(s2, s8);
|
||||
s3 = _mm256_add_epi64(s3, s12);
|
||||
s6 = _mm256_add_epi64(s6, s9);
|
||||
s7 = _mm256_add_epi64(s7, s13);
|
||||
|
||||
s8 = _mm256_and_si256(s14, s11); // -> s2
|
||||
s9 = _mm256_and_si256(s15, s11); // -> s6
|
||||
s14 = _mm256_srli_epi64(s14, 32); // -> s3
|
||||
s15 = _mm256_srli_epi64(s15, 32); // -> s7
|
||||
s2 = _mm256_add_epi64(s2, s8);
|
||||
s3 = _mm256_add_epi64(s3, s14);
|
||||
s6 = _mm256_add_epi64(s6, s9);
|
||||
s7 = _mm256_add_epi64(s7, s15);
|
||||
|
||||
x_ptr += 2;
|
||||
y_ptr += 4;
|
||||
}
|
||||
// final reduction
|
||||
const uint64_t H2 = precomp->h;
|
||||
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||
//--- s0,s1
|
||||
s11 = _mm256_and_si256(s1, s8);
|
||||
s12 = _mm256_srli_epi64(s1, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s0 = _mm256_add_epi64(s0, s13);
|
||||
s0 = _mm256_add_epi64(s0, s14);
|
||||
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||
//--- s2,s3
|
||||
s11 = _mm256_and_si256(s3, s8);
|
||||
s12 = _mm256_srli_epi64(s3, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s2 = _mm256_add_epi64(s2, s13);
|
||||
s2 = _mm256_add_epi64(s2, s14);
|
||||
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||
//--- s4,s5
|
||||
s11 = _mm256_and_si256(s5, s8);
|
||||
s12 = _mm256_srli_epi64(s5, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s4 = _mm256_add_epi64(s4, s13);
|
||||
s4 = _mm256_add_epi64(s4, s14);
|
||||
_mm256_storeu_si256(res_ptr + 2, s4);
|
||||
//--- s6,s7
|
||||
s11 = _mm256_and_si256(s7, s8);
|
||||
s12 = _mm256_srli_epi64(s7, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s6 = _mm256_add_epi64(s6, s13);
|
||||
s6 = _mm256_add_epi64(s6, s14);
|
||||
_mm256_storeu_si256(res_ptr + 3, s6);
|
||||
}
|
||||
|
||||
EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
__m256i s0 = _mm256_setzero_si256(); // col 1a
|
||||
__m256i s1 = _mm256_setzero_si256();
|
||||
__m256i s2 = _mm256_setzero_si256(); // col 1b
|
||||
__m256i s3 = _mm256_setzero_si256();
|
||||
__m256i s4 = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||
__m256i s8, s9, s10, s11;
|
||||
__m256i s12, s13, s14, s15;
|
||||
|
||||
const __m256i* x_ptr = (__m256i*)x;
|
||||
const __m256i* y_ptr = (__m256i*)y;
|
||||
__m256i* res_ptr = (__m256i*)res;
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
s8 = _mm256_loadu_si256(x_ptr);
|
||||
s9 = _mm256_loadu_si256(x_ptr + 1);
|
||||
s10 = _mm256_srli_epi64(s8, 32);
|
||||
s11 = _mm256_srli_epi64(s9, 32);
|
||||
|
||||
s12 = _mm256_loadu_si256(y_ptr);
|
||||
s13 = _mm256_loadu_si256(y_ptr + 1);
|
||||
s14 = _mm256_srli_epi64(s12, 32);
|
||||
s15 = _mm256_srli_epi64(s13, 32);
|
||||
|
||||
s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1
|
||||
s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3
|
||||
s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1
|
||||
s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3
|
||||
|
||||
s8 = _mm256_and_si256(s12, s4); // -> s0
|
||||
s9 = _mm256_and_si256(s13, s4); // -> s2
|
||||
s10 = _mm256_and_si256(s14, s4); // -> s0
|
||||
s11 = _mm256_and_si256(s15, s4); // -> s2
|
||||
s12 = _mm256_srli_epi64(s12, 32); // -> s1
|
||||
s13 = _mm256_srli_epi64(s13, 32); // -> s3
|
||||
s14 = _mm256_srli_epi64(s14, 32); // -> s1
|
||||
s15 = _mm256_srli_epi64(s15, 32); // -> s3
|
||||
|
||||
s0 = _mm256_add_epi64(s0, s8);
|
||||
s1 = _mm256_add_epi64(s1, s12);
|
||||
s2 = _mm256_add_epi64(s2, s9);
|
||||
s3 = _mm256_add_epi64(s3, s13);
|
||||
s0 = _mm256_add_epi64(s0, s10);
|
||||
s1 = _mm256_add_epi64(s1, s14);
|
||||
s2 = _mm256_add_epi64(s2, s11);
|
||||
s3 = _mm256_add_epi64(s3, s15);
|
||||
|
||||
x_ptr += 2;
|
||||
y_ptr += 2;
|
||||
}
|
||||
// final reduction
|
||||
const uint64_t H2 = precomp->h;
|
||||
s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2
|
||||
s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED
|
||||
s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED
|
||||
//--- s0,s1
|
||||
s11 = _mm256_and_si256(s1, s8);
|
||||
s12 = _mm256_srli_epi64(s1, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s0 = _mm256_add_epi64(s0, s13);
|
||||
s0 = _mm256_add_epi64(s0, s14);
|
||||
_mm256_storeu_si256(res_ptr + 0, s0);
|
||||
//--- s2,s3
|
||||
s11 = _mm256_and_si256(s3, s8);
|
||||
s12 = _mm256_srli_epi64(s3, H2);
|
||||
s13 = _mm256_mul_epu32(s11, s9);
|
||||
s14 = _mm256_mul_epu32(s12, s10);
|
||||
s2 = _mm256_add_epi64(s2, s13);
|
||||
s2 = _mm256_add_epi64(s2, s14);
|
||||
_mm256_storeu_si256(res_ptr + 1, s2);
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
#ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||
#define SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
typedef struct _q120_mat1col_product_baa_precomp {
|
||||
uint64_t h;
|
||||
uint64_t h_pow_red[4];
|
||||
#ifndef NDEBUG
|
||||
double res_bit_size;
|
||||
#endif
|
||||
} q120_mat1col_product_baa_precomp;
|
||||
|
||||
typedef struct _q120_mat1col_product_bbb_precomp {
|
||||
uint64_t h;
|
||||
uint64_t s1h_pow_red[4];
|
||||
uint64_t s2l_pow_red[4];
|
||||
uint64_t s2h_pow_red[4];
|
||||
uint64_t s3l_pow_red[4];
|
||||
uint64_t s3h_pow_red[4];
|
||||
uint64_t s4l_pow_red[4];
|
||||
uint64_t s4h_pow_red[4];
|
||||
#ifndef NDEBUG
|
||||
double res_bit_size;
|
||||
#endif
|
||||
} q120_mat1col_product_bbb_precomp;
|
||||
|
||||
typedef struct _q120_mat1col_product_bbc_precomp {
|
||||
uint64_t h;
|
||||
uint64_t s2l_pow_red[4];
|
||||
uint64_t s2h_pow_red[4];
|
||||
#ifndef NDEBUG
|
||||
double res_bit_size;
|
||||
#endif
|
||||
} q120_mat1col_product_bbc_precomp;
|
||||
|
||||
#endif // SPQLIOS_Q120_ARITHMETIC_DEF_H
|
||||
@@ -1,506 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "q120_arithmetic.h"
|
||||
#include "q120_arithmetic_private.h"
|
||||
#include "q120_common.h"
|
||||
|
||||
#define MODQ(val, q) ((val) % (q))
|
||||
|
||||
double comp_bit_size_red(const uint64_t h, const uint64_t qs[4]) {
|
||||
assert(h < 128);
|
||||
double h_pow2_bs = 0;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
double t = log2((double)MODQ((__uint128_t)1 << h, qs[k]));
|
||||
if (t > h_pow2_bs) h_pow2_bs = t;
|
||||
}
|
||||
return h_pow2_bs;
|
||||
}
|
||||
|
||||
double comp_bit_size_sum(const uint64_t n, const double* const bs) {
|
||||
double s = 0;
|
||||
for (uint64_t i = 0; i < n; ++i) {
|
||||
s += pow(2, bs[i]);
|
||||
}
|
||||
return log2(s);
|
||||
}
|
||||
|
||||
void vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* precomp) {
|
||||
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||
|
||||
double min_res_bs = 1000;
|
||||
uint64_t min_h = -1;
|
||||
|
||||
double ell_bs = log2((double)MAX_ELL);
|
||||
for (uint64_t h = 1; h < 64; ++h) {
|
||||
double h_pow2_bs = comp_bit_size_red(h, qs);
|
||||
|
||||
const double bs[] = {h + ell_bs, 64 - h + ell_bs + h_pow2_bs};
|
||||
const double res_bs = comp_bit_size_sum(2, bs);
|
||||
|
||||
if (min_res_bs > res_bs) {
|
||||
min_res_bs = res_bs;
|
||||
min_h = h;
|
||||
}
|
||||
}
|
||||
|
||||
assert(min_res_bs < 64);
|
||||
precomp->h = min_h;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
precomp->h_pow_red[k] = MODQ(UINT64_C(1) << precomp->h, qs[k]);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
precomp->res_bit_size = min_res_bs;
|
||||
#endif
|
||||
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||
}
|
||||
|
||||
EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp() {
|
||||
q120_mat1col_product_baa_precomp* res = malloc(sizeof(q120_mat1col_product_baa_precomp));
|
||||
vec_mat1col_product_baa_precomp(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* addr) { free(addr); }
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120a* const x, const q120a* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* - res = acc1 + acc2 . ((2^H) % Q)
|
||||
* - acc1 is the sum of H LSB of products x[i].y[i]
|
||||
* - acc2 is the sum of 64-H MSB of products x[i]].y[i]
|
||||
* - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits
|
||||
* - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits
|
||||
*/
|
||||
const uint64_t H = precomp->h;
|
||||
const uint64_t MASK = (UINT64_C(1) << H) - 1;
|
||||
|
||||
uint64_t acc1[4] = {0, 0, 0, 0}; // accumulate H least significant bits of product
|
||||
uint64_t acc2[4] = {0, 0, 0, 0}; // accumulate 64 - H most significan bits of product
|
||||
|
||||
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||
const uint64_t* const y_ptr = (uint64_t*)y;
|
||||
|
||||
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||
for (uint64_t j = 0; j < 4; ++j) {
|
||||
uint64_t t = x_ptr[i + j] * y_ptr[i + j];
|
||||
acc1[j] += t & MASK;
|
||||
acc2[j] += t >> H;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t* const res_ptr = (uint64_t*)res;
|
||||
for (uint64_t j = 0; j < 4; ++j) {
|
||||
res_ptr[j] = acc1[j] + acc2[j] * precomp->h_pow_red[j];
|
||||
assert(log2(res_ptr[j]) < precomp->res_bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
void vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* precomp) {
|
||||
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||
|
||||
double ell_bs = log2((double)MAX_ELL);
|
||||
double min_res_bs = 1000;
|
||||
uint64_t min_h = -1;
|
||||
|
||||
const double s1_bs = 32 + ell_bs;
|
||||
const double s2_bs = 32 + ell_bs + log2(3);
|
||||
const double s3_bs = 32 + ell_bs + log2(3);
|
||||
const double s4_bs = 32 + ell_bs;
|
||||
for (uint64_t h = 16; h < 32; ++h) {
|
||||
const double s1l_bs = h;
|
||||
const double s1h_bs = (s1_bs - h) + comp_bit_size_red(h, qs);
|
||||
const double s2l_bs = h + comp_bit_size_red(32, qs);
|
||||
const double s2h_bs = (s2_bs - h) + comp_bit_size_red(32 + h, qs);
|
||||
const double s3l_bs = h + comp_bit_size_red(64, qs);
|
||||
const double s3h_bs = (s3_bs - h) + comp_bit_size_red(64 + h, qs);
|
||||
const double s4l_bs = h + comp_bit_size_red(96, qs);
|
||||
const double s4h_bs = (s4_bs - h) + comp_bit_size_red(96 + h, qs);
|
||||
|
||||
const double bs[] = {s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs};
|
||||
const double res_bs = comp_bit_size_sum(8, bs);
|
||||
|
||||
if (min_res_bs > res_bs) {
|
||||
min_res_bs = res_bs;
|
||||
min_h = h;
|
||||
}
|
||||
}
|
||||
|
||||
assert(min_res_bs < 64);
|
||||
precomp->h = min_h;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
precomp->s1h_pow_red[k] = UINT64_C(1) << precomp->h; // 2^24
|
||||
precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); // 2^32
|
||||
precomp->s2h_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(32+24)
|
||||
precomp->s3l_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^64 = 2^(32+32)
|
||||
precomp->s3h_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(64+24)
|
||||
precomp->s4l_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^96 = 2^(64+32)
|
||||
precomp->s4h_pow_red[k] = MODQ(precomp->s4l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(96+24)
|
||||
}
|
||||
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||
#ifndef NDEBUG
|
||||
precomp->res_bit_size = min_res_bs;
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp() {
|
||||
q120_mat1col_product_bbb_precomp* res = malloc(sizeof(q120_mat1col_product_bbb_precomp));
|
||||
vec_mat1col_product_bbb_precomp(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* addr) { free(addr); }
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120b* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products:
|
||||
* - x_i = xl_i + xh_i . 2^32
|
||||
* - y_i = yl_i + yh_i . 2^32
|
||||
* - A_i = xl_i . yl_i
|
||||
* - B_i = xl_i . yh_i
|
||||
* - C_i = xh_i . yl_i
|
||||
* - D_i = xh_i . yh_i
|
||||
* - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64
|
||||
* 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts
|
||||
* - A_i = Al_i + Ah_i . 2^32
|
||||
* - B_i = Bl_i + Bh_i . 2^32
|
||||
* - C_i = Cl_i + Ch_i . 2^32
|
||||
* - D_i = Dl_i + Dh_i . 2^32
|
||||
* 3. Compute the sums:
|
||||
* - S1 = \sum Al_i
|
||||
* - S2 = \sum (Ah_i + Bl_i + Cl_i)
|
||||
* - S3 = \sum (Bh_i + Ch_i + Dl_i)
|
||||
* - S4 = \sum Dh_i
|
||||
* - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) +
|
||||
* log2(3) bits
|
||||
* - for ell == 10000 S2, S3 have < 47 bits
|
||||
* 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2))
|
||||
* - S1 = S1l + S1h . 2^24
|
||||
* - S2 = S2l + S2h . 2^24
|
||||
* - S3 = S3l + S3h . 2^24
|
||||
* - S4 = S4l + S4h . 2^24
|
||||
* 5. Compute final result as:
|
||||
* - \sum x_i . y_i = S1l + S1h . 2^24
|
||||
* + S2l . 2^32 + S2h . 2^(32+24)
|
||||
* + S3l . 2^64 + S3h . 2^(64 + 24)
|
||||
* + S4l . 2^96 + S4l . 2^(96+24)
|
||||
* - here the powers of 2 are reduced modulo the primes Q before
|
||||
* multiplications
|
||||
* - the result will be on 24 + 3 + bit size of primes Q
|
||||
*/
|
||||
const uint64_t H1 = 32;
|
||||
const uint64_t MASK1 = (UINT64_C(1) << H1) - 1;
|
||||
|
||||
uint64_t s1[4] = {0, 0, 0, 0};
|
||||
uint64_t s2[4] = {0, 0, 0, 0};
|
||||
uint64_t s3[4] = {0, 0, 0, 0};
|
||||
uint64_t s4[4] = {0, 0, 0, 0};
|
||||
|
||||
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||
const uint64_t* const y_ptr = (uint64_t*)y;
|
||||
|
||||
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||
for (uint64_t j = 0; j < 4; ++j) {
|
||||
const uint64_t xl = x_ptr[i + j] & MASK1;
|
||||
const uint64_t xh = x_ptr[i + j] >> H1;
|
||||
const uint64_t yl = y_ptr[i + j] & MASK1;
|
||||
const uint64_t yh = y_ptr[i + j] >> H1;
|
||||
|
||||
const uint64_t a = xl * yl;
|
||||
const uint64_t al = a & MASK1;
|
||||
const uint64_t ah = a >> H1;
|
||||
|
||||
const uint64_t b = xl * yh;
|
||||
const uint64_t bl = b & MASK1;
|
||||
const uint64_t bh = b >> H1;
|
||||
|
||||
const uint64_t c = xh * yl;
|
||||
const uint64_t cl = c & MASK1;
|
||||
const uint64_t ch = c >> H1;
|
||||
|
||||
const uint64_t d = xh * yh;
|
||||
const uint64_t dl = d & MASK1;
|
||||
const uint64_t dh = d >> H1;
|
||||
|
||||
s1[j] += al;
|
||||
s2[j] += ah + bl + cl;
|
||||
s3[j] += bh + ch + dl;
|
||||
s4[j] += dh;
|
||||
}
|
||||
}
|
||||
|
||||
const uint64_t H2 = precomp->h;
|
||||
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||
|
||||
uint64_t* const res_ptr = (uint64_t*)res;
|
||||
for (uint64_t j = 0; j < 4; ++j) {
|
||||
const uint64_t s1l = s1[j] & MASK2;
|
||||
const uint64_t s1h = s1[j] >> H2;
|
||||
const uint64_t s2l = s2[j] & MASK2;
|
||||
const uint64_t s2h = s2[j] >> H2;
|
||||
const uint64_t s3l = s3[j] & MASK2;
|
||||
const uint64_t s3h = s3[j] >> H2;
|
||||
const uint64_t s4l = s4[j] & MASK2;
|
||||
const uint64_t s4h = s4[j] >> H2;
|
||||
|
||||
uint64_t t = s1l;
|
||||
t += s1h * precomp->s1h_pow_red[j];
|
||||
t += s2l * precomp->s2l_pow_red[j];
|
||||
t += s2h * precomp->s2h_pow_red[j];
|
||||
t += s3l * precomp->s3l_pow_red[j];
|
||||
t += s3h * precomp->s3h_pow_red[j];
|
||||
t += s4l * precomp->s4l_pow_red[j];
|
||||
t += s4h * precomp->s4h_pow_red[j];
|
||||
|
||||
res_ptr[j] = t;
|
||||
assert(log2(res_ptr[j]) < precomp->res_bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
void vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* precomp) {
|
||||
uint64_t qs[4] = {Q1, Q2, Q3, Q4};
|
||||
|
||||
double min_res_bs = 1000;
|
||||
uint64_t min_h = -1;
|
||||
|
||||
double pow2_32_bs = comp_bit_size_red(32, qs);
|
||||
|
||||
double ell_bs = log2((double)MAX_ELL);
|
||||
double s1_bs = 32 + ell_bs;
|
||||
for (uint64_t h = 16; h < 32; ++h) {
|
||||
double s2l_bs = pow2_32_bs + h;
|
||||
double s2h_bs = s1_bs - h + comp_bit_size_red(32 + h, qs);
|
||||
|
||||
const double bs[] = {s1_bs, s2l_bs, s2h_bs};
|
||||
const double res_bs = comp_bit_size_sum(3, bs);
|
||||
|
||||
if (min_res_bs > res_bs) {
|
||||
min_res_bs = res_bs;
|
||||
min_h = h;
|
||||
}
|
||||
}
|
||||
|
||||
assert(min_res_bs < 64);
|
||||
precomp->h = min_h;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]);
|
||||
precomp->s2h_pow_red[k] = MODQ(UINT64_C(1) << (32 + precomp->h), qs[k]);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
precomp->res_bit_size = min_res_bs;
|
||||
#endif
|
||||
// printf("AA %lu %lf\n", min_h, min_res_bs);
|
||||
}
|
||||
|
||||
EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp() {
|
||||
q120_mat1col_product_bbc_precomp* res = malloc(sizeof(q120_mat1col_product_bbc_precomp));
|
||||
vec_mat1col_product_bbc_precomp(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* addr) { free(addr); }
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_bbc_ref_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
/**
|
||||
* Algorithm:
|
||||
* 0. We have
|
||||
* - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q
|
||||
* 1. Split x_i in 2 32-bit parts and compute the cross-products:
|
||||
* - x_i = xl_i + xh_i . 2^32
|
||||
* - A_i = xl_i . y1_i
|
||||
* - B_i = xh_i . y2_i
|
||||
* - we have x_i . y_i == A_i + B_i
|
||||
* 2. Split A_i and B_i into 2 32-bit parts
|
||||
* - A_i = Al_i + Ah_i . 2^32
|
||||
* - B_i = Bl_i + Bh_i . 2^32
|
||||
* 3. Compute the sums:
|
||||
* - S1 = \sum Al_i + Bl_i
|
||||
* - S2 = \sum Ah_i + Bh_i
|
||||
* - here S1 and S2 have 32 + log2(ell) bits
|
||||
* - for ell == 10000 S1, S2 have < 46 bits
|
||||
* 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46)
|
||||
* - S2 = S2l + S2h . 2^27
|
||||
* 5. Compute final result as:
|
||||
* - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27)
|
||||
* - here the powers of 2 are reduced modulo the primes Q before
|
||||
* multiplications
|
||||
* - the result will be on < 52 bits
|
||||
*/
|
||||
|
||||
const uint64_t H1 = 32;
|
||||
const uint64_t MASK1 = (UINT64_C(1) << H1) - 1;
|
||||
|
||||
uint64_t s1[4] = {0, 0, 0, 0};
|
||||
uint64_t s2[4] = {0, 0, 0, 0};
|
||||
|
||||
const uint64_t* const x_ptr = (uint64_t*)x;
|
||||
const uint32_t* const y_ptr = (uint32_t*)y;
|
||||
|
||||
for (uint64_t i = 0; i < 4 * ell; i += 4) {
|
||||
for (uint64_t j = 0; j < 4; ++j) {
|
||||
const uint64_t xl = x_ptr[i + j] & MASK1;
|
||||
const uint64_t xh = x_ptr[i + j] >> H1;
|
||||
const uint64_t y0 = y_ptr[2 * (i + j)];
|
||||
const uint64_t y1 = y_ptr[2 * (i + j) + 1];
|
||||
|
||||
const uint64_t a = xl * y0;
|
||||
const uint64_t al = a & MASK1;
|
||||
const uint64_t ah = a >> H1;
|
||||
|
||||
const uint64_t b = xh * y1;
|
||||
const uint64_t bl = b & MASK1;
|
||||
const uint64_t bh = b >> H1;
|
||||
|
||||
s1[j] += al + bl;
|
||||
s2[j] += ah + bh;
|
||||
}
|
||||
}
|
||||
|
||||
const uint64_t H2 = precomp->h;
|
||||
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||
|
||||
uint64_t* const res_ptr = (uint64_t*)res;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
const uint64_t s2l = s2[k] & MASK2;
|
||||
const uint64_t s2h = s2[k] >> H2;
|
||||
|
||||
uint64_t t = s1[k];
|
||||
t += s2l * precomp->s2l_pow_red[k];
|
||||
t += s2h * precomp->s2h_pow_red[k];
|
||||
|
||||
res_ptr[k] = t;
|
||||
assert(log2(res_ptr[k]) < precomp->res_bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
static __always_inline void accum_mul_q120_bc(uint64_t res[8], //
|
||||
const uint32_t x_layb[8], const uint32_t y_layc[8]) {
|
||||
for (uint64_t i = 0; i < 4; ++i) {
|
||||
static const uint64_t MASK32 = 0xFFFFFFFFUL;
|
||||
uint64_t x_lo = x_layb[2 * i];
|
||||
uint64_t x_hi = x_layb[2 * i + 1];
|
||||
uint64_t y_lo = y_layc[2 * i];
|
||||
uint64_t y_hi = y_layc[2 * i + 1];
|
||||
uint64_t xy_lo = x_lo * y_lo;
|
||||
uint64_t xy_hi = x_hi * y_hi;
|
||||
res[2 * i] += (xy_lo & MASK32) + (xy_hi & MASK32);
|
||||
res[2 * i + 1] += (xy_lo >> 32) + (xy_hi >> 32);
|
||||
}
|
||||
}
|
||||
|
||||
static __always_inline void accum_to_q120b(uint64_t res[4], //
|
||||
const uint64_t s[8], const q120_mat1col_product_bbc_precomp* precomp) {
|
||||
const uint64_t H2 = precomp->h;
|
||||
const uint64_t MASK2 = (UINT64_C(1) << H2) - 1;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
const uint64_t s2l = s[2 * k + 1] & MASK2;
|
||||
const uint64_t s2h = s[2 * k + 1] >> H2;
|
||||
uint64_t t = s[2 * k];
|
||||
t += s2l * precomp->s2l_pow_red[k];
|
||||
t += s2h * precomp->s2h_pow_red[k];
|
||||
res[k] = t;
|
||||
assert(log2(res[k]) < precomp->res_bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
const uint32_t(*const x_ptr)[8] = (const uint32_t(*const)[8])x;
|
||||
const uint32_t(*const y_ptr)[8] = (const uint32_t(*const)[8])y;
|
||||
|
||||
for (uint64_t i = 0; i < ell; i++) {
|
||||
accum_mul_q120_bc(s, x_ptr[i], y_ptr[i]);
|
||||
}
|
||||
accum_to_q120b((uint64_t*)res, s, precomp);
|
||||
}
|
||||
|
||||
EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
uint64_t s[2][16] = {0};
|
||||
|
||||
const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x;
|
||||
const uint32_t(*const y_ptr)[2][8] = (const uint32_t(*const)[2][8])y;
|
||||
uint64_t(*re)[4] = (uint64_t(*)[4])res;
|
||||
|
||||
for (uint64_t i = 0; i < ell; i++) {
|
||||
accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]);
|
||||
accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]);
|
||||
}
|
||||
accum_to_q120b(re[0], s[0], precomp);
|
||||
accum_to_q120b(re[1], s[1], precomp);
|
||||
}
|
||||
|
||||
EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell,
|
||||
q120b* const res, const q120b* const x, const q120c* const y) {
|
||||
uint64_t s[4][16] = {0};
|
||||
|
||||
const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x;
|
||||
const uint32_t(*const y_ptr)[4][8] = (const uint32_t(*const)[4][8])y;
|
||||
uint64_t(*re)[4] = (uint64_t(*)[4])res;
|
||||
|
||||
for (uint64_t i = 0; i < ell; i++) {
|
||||
accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]);
|
||||
accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]);
|
||||
accum_mul_q120_bc(s[2], x_ptr[i][0], y_ptr[i][2]);
|
||||
accum_mul_q120_bc(s[3], x_ptr[i][1], y_ptr[i][3]);
|
||||
}
|
||||
accum_to_q120b(re[0], s[0], precomp);
|
||||
accum_to_q120b(re[1], s[1], precomp);
|
||||
accum_to_q120b(re[2], s[2], precomp);
|
||||
accum_to_q120b(re[3], s[3], precomp);
|
||||
}
|
||||
|
||||
EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk,
|
||||
q120x2b* const dst, // 8 doubles
|
||||
const q120b* const src // a q120b vector
|
||||
) {
|
||||
const uint64_t* in = (uint64_t*)src;
|
||||
uint64_t* out = (uint64_t*)dst;
|
||||
for (uint64_t i = 0; i < 8; ++i) {
|
||||
out[i] = in[8 * blk + i];
|
||||
}
|
||||
}
|
||||
|
||||
// function on layout c is the exact same as on layout b
|
||||
#ifdef __APPLE__
|
||||
#pragma weak q120x2_extract_1blk_from_q120c_ref = q120x2_extract_1blk_from_q120b_ref
|
||||
#else
|
||||
EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk,
|
||||
q120x2c* const dst, // 8 doubles
|
||||
const q120c* const src // a q120c vector
|
||||
) __attribute__((alias("q120x2_extract_1blk_from_q120b_ref")));
|
||||
#endif
|
||||
|
||||
EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref(
|
||||
uint64_t nn, uint64_t nrows, uint64_t blk,
|
||||
q120x2b* const dst, // nrows * 2 q120
|
||||
const q120b* const src // a contiguous array of nrows q120b vectors
|
||||
) {
|
||||
const uint64_t* in = (uint64_t*)src;
|
||||
uint64_t* out = (uint64_t*)dst;
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
for (uint64_t i = 0; i < 8; ++i) {
|
||||
out[i] = in[8 * blk + i];
|
||||
}
|
||||
in += 4 * nn;
|
||||
out += 8;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk,
|
||||
q120b* dest, // 1 reim vector of length m
|
||||
const q120x2b* src // 8 doubles
|
||||
) {
|
||||
const uint64_t* in = (uint64_t*)src;
|
||||
uint64_t* out = (uint64_t*)dest;
|
||||
for (uint64_t i = 0; i < 8; ++i) {
|
||||
out[8 * blk + i] = in[i];
|
||||
}
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "q120_arithmetic.h"
|
||||
#include "q120_common.h"
|
||||
|
||||
EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y) {
|
||||
const uint64_t* x_u64 = (uint64_t*)x;
|
||||
const uint64_t* y_u64 = (uint64_t*)y;
|
||||
uint64_t* res_u64 = (uint64_t*)res;
|
||||
for (uint64_t i = 0; i < 4 * nn; i += 4) {
|
||||
res_u64[i + 0] = x_u64[i + 0] % ((uint64_t)Q1 << 33) + y_u64[i + 0] % ((uint64_t)Q1 << 33);
|
||||
res_u64[i + 1] = x_u64[i + 1] % ((uint64_t)Q2 << 33) + y_u64[i + 1] % ((uint64_t)Q2 << 33);
|
||||
res_u64[i + 2] = x_u64[i + 2] % ((uint64_t)Q3 << 33) + y_u64[i + 2] % ((uint64_t)Q3 << 33);
|
||||
res_u64[i + 3] = x_u64[i + 3] % ((uint64_t)Q4 << 33) + y_u64[i + 3] % ((uint64_t)Q4 << 33);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y) {
|
||||
const uint32_t* x_u32 = (uint32_t*)x;
|
||||
const uint32_t* y_u32 = (uint32_t*)y;
|
||||
uint32_t* res_u32 = (uint32_t*)res;
|
||||
for (uint64_t i = 0; i < 8 * nn; i += 8) {
|
||||
res_u32[i + 0] = (uint32_t)(((uint64_t)x_u32[i + 0] + (uint64_t)y_u32[i + 0]) % Q1);
|
||||
res_u32[i + 1] = (uint32_t)(((uint64_t)x_u32[i + 1] + (uint64_t)y_u32[i + 1]) % Q1);
|
||||
res_u32[i + 2] = (uint32_t)(((uint64_t)x_u32[i + 2] + (uint64_t)y_u32[i + 2]) % Q2);
|
||||
res_u32[i + 3] = (uint32_t)(((uint64_t)x_u32[i + 3] + (uint64_t)y_u32[i + 3]) % Q2);
|
||||
res_u32[i + 4] = (uint32_t)(((uint64_t)x_u32[i + 4] + (uint64_t)y_u32[i + 4]) % Q3);
|
||||
res_u32[i + 5] = (uint32_t)(((uint64_t)x_u32[i + 5] + (uint64_t)y_u32[i + 5]) % Q3);
|
||||
res_u32[i + 6] = (uint32_t)(((uint64_t)x_u32[i + 6] + (uint64_t)y_u32[i + 6]) % Q4);
|
||||
res_u32[i + 7] = (uint32_t)(((uint64_t)x_u32[i + 7] + (uint64_t)y_u32[i + 7]) % Q4);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x) {
|
||||
const uint64_t* x_u64 = (uint64_t*)x;
|
||||
uint32_t* res_u32 = (uint32_t*)res;
|
||||
for (uint64_t i = 0, j = 0; i < 4 * nn; i += 4, j += 8) {
|
||||
res_u32[j + 0] = x_u64[i + 0] % Q1;
|
||||
res_u32[j + 1] = ((uint64_t)res_u32[j + 0] << 32) % Q1;
|
||||
res_u32[j + 2] = x_u64[i + 1] % Q2;
|
||||
res_u32[j + 3] = ((uint64_t)res_u32[j + 2] << 32) % Q2;
|
||||
res_u32[j + 4] = x_u64[i + 2] % Q3;
|
||||
res_u32[j + 5] = ((uint64_t)res_u32[j + 4] << 32) % Q3;
|
||||
res_u32[j + 6] = x_u64[i + 3] % Q4;
|
||||
res_u32[j + 7] = ((uint64_t)res_u32[j + 6] << 32) % Q4;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x) {
|
||||
static const int64_t MASK_HI = INT64_C(0x8000000000000000);
|
||||
static const int64_t MASK_LO = ~MASK_HI;
|
||||
static const uint64_t OQ[4] = {
|
||||
(Q1 - (UINT64_C(0x8000000000000000) % Q1)),
|
||||
(Q2 - (UINT64_C(0x8000000000000000) % Q2)),
|
||||
(Q3 - (UINT64_C(0x8000000000000000) % Q3)),
|
||||
(Q4 - (UINT64_C(0x8000000000000000) % Q4)),
|
||||
};
|
||||
uint64_t* res_u64 = (uint64_t*)res;
|
||||
for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) {
|
||||
uint64_t xj_lo = x[j] & MASK_LO;
|
||||
uint64_t xj_hi = x[j] & MASK_HI;
|
||||
res_u64[i + 0] = xj_lo + (xj_hi ? OQ[0] : 0);
|
||||
res_u64[i + 1] = xj_lo + (xj_hi ? OQ[1] : 0);
|
||||
res_u64[i + 2] = xj_lo + (xj_hi ? OQ[2] : 0);
|
||||
res_u64[i + 3] = xj_lo + (xj_hi ? OQ[3] : 0);
|
||||
}
|
||||
}
|
||||
|
||||
static int64_t posmod(int64_t x, int64_t q) {
|
||||
int64_t t = x % q;
|
||||
if (t < 0)
|
||||
return t + q;
|
||||
else
|
||||
return t;
|
||||
}
|
||||
|
||||
EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x) {
|
||||
uint32_t* res_u32 = (uint32_t*)res;
|
||||
for (uint64_t i = 0, j = 0; j < nn; i += 8, ++j) {
|
||||
res_u32[i + 0] = posmod(x[j], Q1);
|
||||
res_u32[i + 1] = ((uint64_t)res_u32[i + 0] << 32) % Q1;
|
||||
res_u32[i + 2] = posmod(x[j], Q2);
|
||||
res_u32[i + 3] = ((uint64_t)res_u32[i + 2] << 32) % Q2;
|
||||
res_u32[i + 4] = posmod(x[j], Q3);
|
||||
res_u32[i + 5] = ((uint64_t)res_u32[i + 4] << 32) % Q3;
|
||||
res_u32[i + 6] = posmod(x[j], Q4);
|
||||
res_u32[i + 7] = ((uint64_t)res_u32[i + 6] << 32) % Q4;
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x) {
|
||||
static const __int128_t Q = (__int128_t)Q1 * Q2 * Q3 * Q4;
|
||||
static const __int128_t Qm1 = (__int128_t)Q2 * Q3 * Q4;
|
||||
static const __int128_t Qm2 = (__int128_t)Q1 * Q3 * Q4;
|
||||
static const __int128_t Qm3 = (__int128_t)Q1 * Q2 * Q4;
|
||||
static const __int128_t Qm4 = (__int128_t)Q1 * Q2 * Q3;
|
||||
|
||||
const uint64_t* x_u64 = (uint64_t*)x;
|
||||
for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) {
|
||||
__int128_t tmp = 0;
|
||||
tmp += (((x_u64[i + 0] % Q1) * Q1_CRT_CST) % Q1) * Qm1;
|
||||
tmp += (((x_u64[i + 1] % Q2) * Q2_CRT_CST) % Q2) * Qm2;
|
||||
tmp += (((x_u64[i + 2] % Q3) * Q3_CRT_CST) % Q3) * Qm3;
|
||||
tmp += (((x_u64[i + 3] % Q4) * Q4_CRT_CST) % Q4) * Qm4;
|
||||
tmp %= Q;
|
||||
res[j] = (tmp >= (Q + 1) / 2) ? tmp - Q : tmp;
|
||||
}
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
#ifndef SPQLIOS_Q120_COMMON_H
|
||||
#define SPQLIOS_Q120_COMMON_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \
|
||||
!defined(SPQLIOS_Q120_USE_31_BIT_PRIMES)
|
||||
#define SPQLIOS_Q120_USE_30_BIT_PRIMES
|
||||
#endif
|
||||
|
||||
/**
|
||||
* 29-bit primes and 2*2^16 roots of unity
|
||||
*/
|
||||
#ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES
|
||||
#define Q1 ((1u << 29) - 2 * (1u << 17) + 1)
|
||||
#define OMEGA1 78289835
|
||||
#define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1
|
||||
|
||||
#define Q2 ((1u << 29) - 5 * (1u << 17) + 1)
|
||||
#define OMEGA2 178519192
|
||||
#define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2
|
||||
|
||||
#define Q3 ((1u << 29) - 26 * (1u << 17) + 1)
|
||||
#define OMEGA3 483889678
|
||||
#define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3
|
||||
|
||||
#define Q4 ((1u << 29) - 35 * (1u << 17) + 1)
|
||||
#define OMEGA4 239808033
|
||||
#define Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4
|
||||
#endif
|
||||
|
||||
/**
|
||||
* 30-bit primes and 2*2^16 roots of unity
|
||||
*/
|
||||
#ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES
|
||||
#define Q1 ((1u << 30) - 2 * (1u << 17) + 1)
|
||||
#define OMEGA1 1070907127
|
||||
#define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1
|
||||
|
||||
#define Q2 ((1u << 30) - 17 * (1u << 17) + 1)
|
||||
#define OMEGA2 315046632
|
||||
#define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2
|
||||
|
||||
#define Q3 ((1u << 30) - 23 * (1u << 17) + 1)
|
||||
#define OMEGA3 309185662
|
||||
#define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3
|
||||
|
||||
#define Q4 ((1u << 30) - 42 * (1u << 17) + 1)
|
||||
#define OMEGA4 846468380
|
||||
#define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4
|
||||
#endif
|
||||
|
||||
/**
|
||||
* 31-bit primes and 2*2^16 roots of unity
|
||||
*/
|
||||
#ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES
|
||||
#define Q1 ((1u << 31) - 1 * (1u << 17) + 1)
|
||||
#define OMEGA1 1615402923
|
||||
#define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1
|
||||
|
||||
#define Q2 ((1u << 31) - 4 * (1u << 17) + 1)
|
||||
#define OMEGA2 1137738560
|
||||
#define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2
|
||||
|
||||
#define Q3 ((1u << 31) - 11 * (1u << 17) + 1)
|
||||
#define OMEGA3 154880552
|
||||
#define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3
|
||||
|
||||
#define Q4 ((1u << 31) - 23 * (1u << 17) + 1)
|
||||
#define OMEGA4 558784885
|
||||
#define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4
|
||||
#endif
|
||||
|
||||
static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4};
|
||||
static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4};
|
||||
|
||||
#define MAX_ELL 10000
|
||||
|
||||
// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
|
||||
// each between [0 and 2^32-1]
|
||||
typedef struct _q120a q120a;
|
||||
|
||||
// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
|
||||
// each between [0 and 2^64-1]
|
||||
typedef struct _q120b q120b;
|
||||
|
||||
// each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x
|
||||
// mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1]
|
||||
typedef struct _q120c q120c;
|
||||
|
||||
typedef struct _q120x2b q120x2b;
|
||||
typedef struct _q120x2c q120x2c;
|
||||
|
||||
#endif // SPQLIOS_Q120_COMMON_H
|
||||
@@ -1,5 +0,0 @@
|
||||
#include "q120_ntt_private.h"
|
||||
|
||||
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }
|
||||
|
||||
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }
|
||||
@@ -1,340 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <inttypes.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "q120_ntt_private.h"
|
||||
|
||||
q120_ntt_precomp* new_precomp(const uint64_t n) {
|
||||
q120_ntt_precomp* precomp = malloc(sizeof(*precomp));
|
||||
precomp->n = n;
|
||||
|
||||
assert(n && !(n & (n - 1)) && n <= (1 << 16)); // n is a power of 2 smaller than 2^16
|
||||
const uint64_t logN = ceil(log2(n));
|
||||
precomp->level_metadata = malloc((logN + 2) * sizeof(*precomp->level_metadata));
|
||||
|
||||
precomp->powomega = spqlios_alloc_custom_align(32, 4 * 2 * n * sizeof(*(precomp->powomega)));
|
||||
|
||||
return precomp;
|
||||
}
|
||||
|
||||
uint32_t modq_pow(const uint32_t x, const int64_t n, const uint32_t q) {
|
||||
uint64_t np = (n % (q - 1) + q - 1) % (q - 1);
|
||||
|
||||
uint64_t val_pow = x;
|
||||
uint64_t res = 1;
|
||||
while (np != 0) {
|
||||
if (np & 1) res = (res * val_pow) % q;
|
||||
val_pow = (val_pow * val_pow) % q;
|
||||
np >>= 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void fill_omegas(const uint64_t n, uint32_t omegas[4]) {
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
omegas[k] = modq_pow(OMEGAS_VEC[k], (1 << 16) / n, PRIMES_VEC[k]);
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
||||
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||
for (int k = 1; k < 4; ++k) {
|
||||
if (logQ != ceil(log2(PRIMES_VEC[k]))) {
|
||||
fprintf(stderr, "The 4 primes must have the same bit-size\n");
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// check if each omega is a 2.n primitive root of unity
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
assert(modq_pow(omegas[k], 2 * n, PRIMES_VEC[k]) == 1);
|
||||
for (uint64_t i = 1; i < 2 * n; ++i) {
|
||||
assert(modq_pow(omegas[k], i, PRIMES_VEC[k]) != 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (logQ > 31) {
|
||||
fprintf(stderr, "Modulus q bit-size is larger than 30 bit\n");
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
uint64_t fill_reduction_meta(const uint64_t bs_start, q120_ntt_reduc_step_precomp* reduc_metadata) {
|
||||
// fill reduction metadata
|
||||
uint64_t bs_after_reduc = -1;
|
||||
{
|
||||
uint64_t min_h = -1;
|
||||
|
||||
for (uint64_t h = bs_start / 2; h < bs_start; ++h) {
|
||||
uint64_t t = 0;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
const uint64_t t1 = bs_start - h + (uint64_t)ceil(log2((UINT64_C(1) << h) % PRIMES_VEC[k]));
|
||||
const uint64_t t2 = UINT64_C(1) + ((t1 > h) ? t1 : h);
|
||||
if (t < t2) t = t2;
|
||||
}
|
||||
if (t < bs_after_reduc) {
|
||||
min_h = h;
|
||||
bs_after_reduc = t;
|
||||
}
|
||||
}
|
||||
|
||||
reduc_metadata->h = min_h;
|
||||
reduc_metadata->mask = (UINT64_C(1) << min_h) - 1;
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
reduc_metadata->modulo_red_cst[k] = (UINT64_C(1) << min_h) % PRIMES_VEC[k];
|
||||
}
|
||||
|
||||
assert(bs_after_reduc < 64);
|
||||
}
|
||||
|
||||
return bs_after_reduc;
|
||||
}
|
||||
|
||||
uint64_t round_up_half_n(const uint64_t n) { return (n + 1) / 2; }
|
||||
|
||||
EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n) {
|
||||
uint32_t omega_vec[4];
|
||||
fill_omegas(n, omega_vec);
|
||||
|
||||
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||
|
||||
q120_ntt_precomp* precomp = new_precomp(n);
|
||||
|
||||
uint64_t bs = precomp->input_bit_size = 64;
|
||||
|
||||
LOG("NTT parameters:\n");
|
||||
LOG("\tsize = %" PRIu64 "\n", n)
|
||||
LOG("\tlogQ = %" PRIu64 "\n", logQ);
|
||||
LOG("\tinput bit-size = %" PRIu64 "\n", bs);
|
||||
|
||||
if (n == 1) return precomp;
|
||||
|
||||
// fill reduction metadata
|
||||
uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata));
|
||||
|
||||
// forward metadata
|
||||
q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata;
|
||||
|
||||
// first level a_k.omega^k
|
||||
{
|
||||
const uint64_t half_bs = (bs + 1) / 2;
|
||||
level_metadata_ptr->half_bs = half_bs;
|
||||
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||
level_metadata_ptr->bs = bs = half_bs + logQ + 1;
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 " (a_k.omega^k) \n", n, bs);
|
||||
level_metadata_ptr++;
|
||||
}
|
||||
|
||||
for (uint64_t nn = n; nn >= 4; nn /= 2) {
|
||||
level_metadata_ptr->reduce = (bs == 64);
|
||||
if (level_metadata_ptr->reduce) {
|
||||
bs = bs_after_reduc;
|
||||
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||
}
|
||||
|
||||
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||
|
||||
double bs_1 = bs + 1; // bit-size of term a+b or a-b
|
||||
|
||||
const uint64_t half_bs = round_up_half_n(bs_1);
|
||||
uint64_t bs_2 = half_bs + logQ + 1; // bit-size of term (a-b).omega^k
|
||||
bs = (bs_1 > bs_2) ? bs_1 : bs_2;
|
||||
assert(bs <= 64);
|
||||
|
||||
level_metadata_ptr->bs = bs;
|
||||
level_metadata_ptr->half_bs = half_bs;
|
||||
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||
level_metadata_ptr++;
|
||||
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs);
|
||||
}
|
||||
|
||||
// last level (a-b, a+b)
|
||||
{
|
||||
level_metadata_ptr->reduce = (bs == 64);
|
||||
if (level_metadata_ptr->reduce) {
|
||||
bs = bs_after_reduc;
|
||||
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||
}
|
||||
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = ((uint64_t)PRIMES_VEC[k] << (bs - logQ));
|
||||
level_metadata_ptr->bs = ++bs;
|
||||
level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used
|
||||
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs);
|
||||
}
|
||||
precomp->output_bit_size = bs;
|
||||
|
||||
// omega powers
|
||||
uint64_t* powomega = malloc(sizeof(*powomega) * 2 * n);
|
||||
for (uint64_t k = 0; k < 4; ++k) {
|
||||
const uint64_t q = PRIMES_VEC[k];
|
||||
|
||||
for (uint64_t i = 0; i < 2 * n; ++i) {
|
||||
powomega[i] = modq_pow(omega_vec[k], i, q);
|
||||
}
|
||||
|
||||
uint64_t* powomega_ptr = precomp->powomega + k;
|
||||
level_metadata_ptr = precomp->level_metadata;
|
||||
|
||||
{
|
||||
// const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs;
|
||||
for (uint64_t i = 0; i < n; ++i) {
|
||||
uint64_t t = powomega[i];
|
||||
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||
powomega_ptr[4 * i] = (t1 << 32) + t;
|
||||
}
|
||||
powomega_ptr += 4 * n;
|
||||
level_metadata_ptr++;
|
||||
}
|
||||
|
||||
for (uint64_t nn = n; nn >= 4; nn /= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
const uint64_t m = n / halfnn;
|
||||
|
||||
// const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs;
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
uint64_t t = powomega[i * m];
|
||||
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||
powomega_ptr[4 * (i - 1)] = (t1 << 32) + t;
|
||||
}
|
||||
powomega_ptr += 4 * (halfnn - 1);
|
||||
level_metadata_ptr++;
|
||||
}
|
||||
}
|
||||
free(powomega);
|
||||
|
||||
return precomp;
|
||||
}
|
||||
|
||||
EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n) {
|
||||
uint32_t omega_vec[4];
|
||||
fill_omegas(n, omega_vec);
|
||||
|
||||
const uint64_t logQ = ceil(log2(PRIMES_VEC[0]));
|
||||
|
||||
q120_ntt_precomp* precomp = new_precomp(n);
|
||||
|
||||
uint64_t bs = precomp->input_bit_size = 64;
|
||||
|
||||
LOG("iNTT parameters:\n");
|
||||
LOG("\tsize = %" PRIu64 "\n", n)
|
||||
LOG("\tlogQ = %" PRIu64 "\n", logQ);
|
||||
LOG("\tinput bit-size = %" PRIu64 "\n", bs);
|
||||
|
||||
if (n == 1) return precomp;
|
||||
|
||||
// fill reduction metadata
|
||||
uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata));
|
||||
|
||||
// backward metadata
|
||||
q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata;
|
||||
|
||||
// first level (a+b, a-b) adds 1-bit
|
||||
{
|
||||
level_metadata_ptr->reduce = (bs == 64);
|
||||
if (level_metadata_ptr->reduce) {
|
||||
bs = bs_after_reduc;
|
||||
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||
}
|
||||
|
||||
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||
|
||||
level_metadata_ptr->bs = ++bs;
|
||||
level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs);
|
||||
level_metadata_ptr++;
|
||||
}
|
||||
|
||||
for (uint64_t nn = 4; nn <= n; nn *= 2) {
|
||||
level_metadata_ptr->reduce = (bs == 64);
|
||||
if (level_metadata_ptr->reduce) {
|
||||
bs = bs_after_reduc;
|
||||
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||
}
|
||||
|
||||
const uint64_t half_bs = round_up_half_n(bs);
|
||||
const uint64_t bs_mult = half_bs + logQ + 1; // bit-size of term b.omega^k
|
||||
bs = 1 + ((bs > bs_mult) ? bs : bs_mult); // bit-size of a+b.omega^k or a-b.omega^k
|
||||
assert(bs <= 64);
|
||||
|
||||
level_metadata_ptr->bs = bs;
|
||||
level_metadata_ptr->half_bs = half_bs;
|
||||
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs_mult - logQ);
|
||||
level_metadata_ptr++;
|
||||
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs);
|
||||
}
|
||||
|
||||
// last level a_k.omega^k
|
||||
{
|
||||
level_metadata_ptr->reduce = (bs == 64);
|
||||
if (level_metadata_ptr->reduce) {
|
||||
bs = bs_after_reduc;
|
||||
LOG("\treduce output bit-size = %" PRIu64 "\n", bs);
|
||||
}
|
||||
|
||||
const uint64_t half_bs = round_up_half_n(bs);
|
||||
|
||||
bs = half_bs + logQ + 1; // bit-size of term a.omega^k
|
||||
assert(bs <= 64);
|
||||
|
||||
level_metadata_ptr->bs = bs;
|
||||
level_metadata_ptr->half_bs = half_bs;
|
||||
level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1);
|
||||
for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ);
|
||||
|
||||
LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", n, bs);
|
||||
}
|
||||
|
||||
// omega powers
|
||||
uint32_t* powomegabar = malloc(sizeof(*powomegabar) * 2 * n);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const uint64_t q = PRIMES_VEC[k];
|
||||
|
||||
for (uint64_t i = 0; i < 2 * n; ++i) {
|
||||
powomegabar[i] = modq_pow(omega_vec[k], -i, q);
|
||||
}
|
||||
|
||||
uint64_t* powomega_ptr = precomp->powomega + k;
|
||||
level_metadata_ptr = precomp->level_metadata + 1;
|
||||
|
||||
for (uint64_t nn = 4; nn <= n; nn *= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
const uint64_t m = n / halfnn;
|
||||
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
uint64_t t = powomegabar[i * m];
|
||||
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||
powomega_ptr[4 * (i - 1)] = (t1 << 32) + t;
|
||||
}
|
||||
powomega_ptr += 4 * (halfnn - 1);
|
||||
level_metadata_ptr++;
|
||||
}
|
||||
|
||||
{
|
||||
const uint64_t invNmod = modq_pow(n, -1, q);
|
||||
for (uint64_t i = 0; i < n; ++i) {
|
||||
uint64_t t = (powomegabar[i] * invNmod) % q;
|
||||
uint64_t t1 = (t << level_metadata_ptr->half_bs) % q;
|
||||
powomega_ptr[4 * i] = (t1 << 32) + t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(powomegabar);
|
||||
|
||||
return precomp;
|
||||
}
|
||||
|
||||
void del_precomp(q120_ntt_precomp* precomp) {
|
||||
spqlios_free(precomp->powomega);
|
||||
free(precomp->level_metadata);
|
||||
free(precomp);
|
||||
}
|
||||
|
||||
EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); }
|
||||
|
||||
EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); }
|
||||
@@ -1,25 +0,0 @@
|
||||
#ifndef SPQLIOS_Q120_NTT_H
|
||||
#define SPQLIOS_Q120_NTT_H
|
||||
|
||||
#include "../commons.h"
|
||||
#include "q120_common.h"
|
||||
|
||||
typedef struct _q120_ntt_precomp q120_ntt_precomp;
|
||||
|
||||
EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n);
|
||||
EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp);
|
||||
|
||||
EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n);
|
||||
EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp);
|
||||
|
||||
/**
|
||||
* @brief computes a direct ntt in-place over data.
|
||||
*/
|
||||
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);
|
||||
|
||||
/**
|
||||
* @brief computes an inverse ntt in-place over data.
|
||||
*/
|
||||
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);
|
||||
|
||||
#endif // SPQLIOS_Q120_NTT_H
|
||||
@@ -1,479 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <immintrin.h>
|
||||
#include <inttypes.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "q120_common.h"
|
||||
#include "q120_ntt_private.h"
|
||||
|
||||
// at which level to switch from computations by level to computations by block
|
||||
#define CHANGE_MODE_N 1024
|
||||
|
||||
__always_inline __m256i split_precompmul_si256(__m256i inp, __m256i powomega, const uint64_t h, const __m256i mask) {
|
||||
const __m256i inp_low = _mm256_and_si256(inp, mask);
|
||||
const __m256i t1 = _mm256_mul_epu32(inp_low, powomega);
|
||||
|
||||
const __m256i inp_high = _mm256_srli_epi64(inp, h);
|
||||
const __m256i powomega_high = _mm256_srli_epi64(powomega, 32);
|
||||
const __m256i t2 = _mm256_mul_epu32(inp_high, powomega_high);
|
||||
|
||||
return _mm256_add_epi64(t1, t2);
|
||||
}
|
||||
|
||||
__always_inline __m256i modq_red(const __m256i x, const uint64_t h, const __m256i mask, const __m256i _2_pow_h_modq) {
|
||||
__m256i xh = _mm256_srli_epi64(x, h);
|
||||
__m256i xl = _mm256_and_si256(x, mask);
|
||||
__m256i xh_1 = _mm256_mul_epu32(xh, _2_pow_h_modq);
|
||||
return _mm256_add_epi64(xl, xh_1);
|
||||
}
|
||||
|
||||
void print_data(const uint64_t n, const uint64_t* const data, const uint64_t q) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
printf("%" PRIu64 " ", *(data + i) % q);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
double max_bit_size(const void* const begin, const void* const end) {
|
||||
double bs = 0;
|
||||
const uint64_t* data = (uint64_t*)begin;
|
||||
while (data != end) {
|
||||
double t = log2(*(data++));
|
||||
if (bs < t) {
|
||||
bs = t;
|
||||
}
|
||||
}
|
||||
return bs;
|
||||
}
|
||||
|
||||
void ntt_iter_first(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData,
|
||||
const __m256i* powomega) {
|
||||
const uint64_t h = itData->half_bs;
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i x = _mm256_loadu_si256(data);
|
||||
__m256i po = _mm256_loadu_si256(powomega);
|
||||
__m256i r = split_precompmul_si256(x, po, h, vmask);
|
||||
_mm256_storeu_si256(data, r);
|
||||
|
||||
data++;
|
||||
powomega++;
|
||||
}
|
||||
}
|
||||
|
||||
void ntt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end,
|
||||
const q120_ntt_step_precomp* const itData, const __m256i* const powomega) {
|
||||
assert(nn % 2 == 0);
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs);
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i* ptr1 = data;
|
||||
__m256i* ptr2 = data + halfnn;
|
||||
|
||||
const __m256i a = _mm256_loadu_si256(ptr1);
|
||||
const __m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
const __m256i ap = _mm256_add_epi64(a, b);
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
|
||||
const __m256i* po_ptr = powomega;
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
__m256i ap = _mm256_add_epi64(a, b);
|
||||
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
__m256i b1 = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
__m256i po = _mm256_loadu_si256(po_ptr);
|
||||
|
||||
__m256i bp = split_precompmul_si256(b1, po, itData->half_bs, vmask);
|
||||
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
po_ptr++;
|
||||
}
|
||||
data += nn;
|
||||
}
|
||||
}
|
||||
|
||||
void ntt_iter_red(const uint64_t nn, __m256i* const begin, const __m256i* const end,
|
||||
const q120_ntt_step_precomp* const itData, const __m256i* const powomega,
|
||||
const q120_ntt_reduc_step_precomp* const reduc_precomp) {
|
||||
assert(nn % 2 == 0);
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs);
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask);
|
||||
const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i* ptr1 = data;
|
||||
__m256i* ptr2 = data + halfnn;
|
||||
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
|
||||
const __m256i ap = _mm256_add_epi64(a, b);
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
|
||||
const __m256i* po_ptr = powomega;
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
|
||||
__m256i ap = _mm256_add_epi64(a, b);
|
||||
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
__m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
__m256i po = _mm256_loadu_si256(po_ptr);
|
||||
bp = split_precompmul_si256(bp, po, itData->half_bs, vmask);
|
||||
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
po_ptr++;
|
||||
}
|
||||
data += nn;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) {
|
||||
// assert((size_t)data_ptr % 32 == 0); // alignment check
|
||||
|
||||
const uint64_t n = precomp->n;
|
||||
if (n == 1) return;
|
||||
|
||||
const q120_ntt_step_precomp* itData = precomp->level_metadata;
|
||||
const __m256i* powomega = (__m256i*)precomp->powomega;
|
||||
|
||||
__m256i* const begin = (__m256i*)data_ptr;
|
||||
const __m256i* const end = ((__m256i*)data_ptr) + n;
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size);
|
||||
assert(bs <= precomp->input_bit_size);
|
||||
}
|
||||
|
||||
// first iteration a_k.omega^k
|
||||
ntt_iter_first(begin, end, itData, powomega);
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs);
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
|
||||
powomega += n;
|
||||
itData++;
|
||||
|
||||
const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N;
|
||||
// const uint64_t split_nn = 2;
|
||||
|
||||
// computations by level
|
||||
for (uint64_t nn = n; nn > split_nn; nn /= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
if (itData->reduce) {
|
||||
ntt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata);
|
||||
} else {
|
||||
ntt_iter(nn, begin, end, itData, powomega);
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? '*' : ' ');
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
|
||||
powomega += halfnn - 1;
|
||||
itData++;
|
||||
}
|
||||
|
||||
// computations by memory block
|
||||
if (split_nn >= 2) {
|
||||
const q120_ntt_step_precomp* itData1 = itData;
|
||||
const __m256i* powomega1 = powomega;
|
||||
for (__m256i* it = begin; it < end; it += split_nn) {
|
||||
__m256i* const begin1 = it;
|
||||
const __m256i* const end1 = it + split_nn;
|
||||
|
||||
itData = itData1;
|
||||
powomega = powomega1;
|
||||
for (uint64_t nn = split_nn; nn >= 2; nn /= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
if (itData->reduce) {
|
||||
ntt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata);
|
||||
} else {
|
||||
ntt_iter(nn, begin1, end1, itData, powomega);
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1);
|
||||
// LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs);
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
|
||||
powomega += halfnn - 1;
|
||||
itData++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", UINT64_C(1), bs, precomp->output_bit_size);
|
||||
assert(bs < precomp->output_bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
void intt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end,
|
||||
const q120_ntt_step_precomp* const itData, const __m256i* const powomega) {
|
||||
assert(nn % 2 == 0);
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs);
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i* ptr1 = data;
|
||||
__m256i* ptr2 = data + halfnn;
|
||||
|
||||
const __m256i a = _mm256_loadu_si256(ptr1);
|
||||
const __m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
const __m256i ap = _mm256_add_epi64(a, b);
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
|
||||
const __m256i* po_ptr = powomega;
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
__m256i po = _mm256_loadu_si256(po_ptr);
|
||||
__m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask);
|
||||
|
||||
__m256i ap = _mm256_add_epi64(a, bo);
|
||||
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
__m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo);
|
||||
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
po_ptr++;
|
||||
}
|
||||
data += nn;
|
||||
}
|
||||
}
|
||||
|
||||
void intt_iter_red(const uint64_t nn, __m256i* const begin, const __m256i* const end,
|
||||
const q120_ntt_step_precomp* const itData, const __m256i* const powomega,
|
||||
const q120_ntt_reduc_step_precomp* const reduc_precomp) {
|
||||
assert(nn % 2 == 0);
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs);
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask);
|
||||
const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i* ptr1 = data;
|
||||
__m256i* ptr2 = data + halfnn;
|
||||
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
|
||||
const __m256i ap = _mm256_add_epi64(a, b);
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b);
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
|
||||
const __m256i* po_ptr = powomega;
|
||||
for (uint64_t i = 1; i < halfnn; ++i) {
|
||||
__m256i a = _mm256_loadu_si256(ptr1);
|
||||
__m256i b = _mm256_loadu_si256(ptr2);
|
||||
|
||||
a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
|
||||
__m256i po = _mm256_loadu_si256(po_ptr);
|
||||
__m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask);
|
||||
|
||||
__m256i ap = _mm256_add_epi64(a, bo);
|
||||
|
||||
_mm256_storeu_si256(ptr1, ap);
|
||||
|
||||
__m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo);
|
||||
|
||||
_mm256_storeu_si256(ptr2, bp);
|
||||
|
||||
ptr1++;
|
||||
ptr2++;
|
||||
po_ptr++;
|
||||
}
|
||||
data += nn;
|
||||
}
|
||||
}
|
||||
|
||||
void ntt_iter_first_red(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData,
|
||||
const __m256i* powomega, const q120_ntt_reduc_step_precomp* const reduc_precomp) {
|
||||
const uint64_t h = itData->half_bs;
|
||||
const __m256i vmask = _mm256_set1_epi64x(itData->mask);
|
||||
|
||||
const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask);
|
||||
const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst);
|
||||
|
||||
__m256i* data = begin;
|
||||
while (data < end) {
|
||||
__m256i x = _mm256_loadu_si256(data);
|
||||
x = modq_red(x, reduc_precomp->h, reduc_mask, reduc_cst);
|
||||
__m256i po = _mm256_loadu_si256(powomega);
|
||||
__m256i r = split_precompmul_si256(x, po, h, vmask);
|
||||
_mm256_storeu_si256(data, r);
|
||||
|
||||
data++;
|
||||
powomega++;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) {
|
||||
// assert((size_t)data_ptr % 32 == 0); // alignment check
|
||||
|
||||
const uint64_t n = precomp->n;
|
||||
if (n == 1) return;
|
||||
|
||||
const q120_ntt_step_precomp* itData = precomp->level_metadata;
|
||||
const __m256i* powomega = (__m256i*)precomp->powomega;
|
||||
|
||||
__m256i* const begin = (__m256i*)data_ptr;
|
||||
const __m256i* const end = ((__m256i*)data_ptr) + n;
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size);
|
||||
assert(bs <= precomp->input_bit_size);
|
||||
}
|
||||
|
||||
const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N;
|
||||
|
||||
// computations by memory block
|
||||
if (split_nn >= 2) {
|
||||
const q120_ntt_step_precomp* itData1 = itData;
|
||||
const __m256i* powomega1 = powomega;
|
||||
for (__m256i* it = begin; it < end; it += split_nn) {
|
||||
__m256i* const begin1 = it;
|
||||
const __m256i* const end1 = it + split_nn;
|
||||
|
||||
itData = itData1;
|
||||
powomega = powomega1;
|
||||
for (uint64_t nn = 2; nn <= split_nn; nn *= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
if (itData->reduce) {
|
||||
intt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata);
|
||||
} else {
|
||||
intt_iter(nn, begin1, end1, itData, powomega);
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1);
|
||||
// LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs);
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
|
||||
powomega += halfnn - 1;
|
||||
itData++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// computations by level
|
||||
// for (uint64_t nn = 2; nn <= n; nn *= 2) {
|
||||
for (uint64_t nn = 2 * split_nn; nn <= n; nn *= 2) {
|
||||
const uint64_t halfnn = nn / 2;
|
||||
|
||||
if (itData->reduce) {
|
||||
intt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata);
|
||||
} else {
|
||||
intt_iter(nn, begin, end, itData, powomega);
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? '*' : ' ');
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
|
||||
powomega += halfnn - 1;
|
||||
itData++;
|
||||
}
|
||||
|
||||
// last iteration a_k . omega^k . n^-1
|
||||
if (itData->reduce) {
|
||||
ntt_iter_first_red(begin, end, itData, powomega, &precomp->reduc_metadata);
|
||||
} else {
|
||||
ntt_iter_first(begin, end, itData, powomega);
|
||||
}
|
||||
|
||||
if (CHECK_BOUNDS) {
|
||||
double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end);
|
||||
LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs);
|
||||
assert(bs < itData->bs);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user